feat(redis): use asyncio
This commit is contained in:
@@ -39,7 +39,7 @@ from .relationship import (
|
|||||||
)
|
)
|
||||||
from .score_token import ScoreToken
|
from .score_token import ScoreToken
|
||||||
|
|
||||||
from redis import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlalchemy.orm import aliased
|
from sqlalchemy.orm import aliased
|
||||||
|
|||||||
@@ -5,15 +5,11 @@ import json
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
import redis.asyncio as redis
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
try:
|
|
||||||
import redis
|
|
||||||
except ImportError:
|
|
||||||
redis = None
|
|
||||||
|
|
||||||
|
|
||||||
def json_serializer(value):
|
def json_serializer(value):
|
||||||
if isinstance(value, BaseModel | SQLModel):
|
if isinstance(value, BaseModel | SQLModel):
|
||||||
@@ -25,10 +21,7 @@ def json_serializer(value):
|
|||||||
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
|
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
|
||||||
|
|
||||||
# Redis 连接
|
# Redis 连接
|
||||||
if redis:
|
|
||||||
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||||
else:
|
|
||||||
redis_client = None
|
|
||||||
|
|
||||||
|
|
||||||
# 数据库依赖
|
# 数据库依赖
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from app.log import logger
|
|||||||
fetcher: Fetcher | None = None
|
fetcher: Fetcher | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_fetcher() -> Fetcher:
|
async def get_fetcher() -> Fetcher:
|
||||||
global fetcher
|
global fetcher
|
||||||
if fetcher is None:
|
if fetcher is None:
|
||||||
fetcher = Fetcher(
|
fetcher = Fetcher(
|
||||||
@@ -18,11 +18,10 @@ def get_fetcher() -> Fetcher:
|
|||||||
settings.FETCHER_CALLBACK_URL,
|
settings.FETCHER_CALLBACK_URL,
|
||||||
)
|
)
|
||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
if redis:
|
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
|
||||||
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
|
|
||||||
if access_token:
|
if access_token:
|
||||||
fetcher.access_token = str(access_token)
|
fetcher.access_token = str(access_token)
|
||||||
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
|
refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
|
||||||
if refresh_token:
|
if refresh_token:
|
||||||
fetcher.refresh_token = str(refresh_token)
|
fetcher.refresh_token = str(refresh_token)
|
||||||
if not fetcher.access_token or not fetcher.refresh_token:
|
if not fetcher.access_token or not fetcher.refresh_token:
|
||||||
|
|||||||
@@ -59,13 +59,12 @@ class BaseFetcher:
|
|||||||
self.refresh_token = token_data.get("refresh_token", "")
|
self.refresh_token = token_data.get("refresh_token", "")
|
||||||
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
if redis:
|
await redis.set(
|
||||||
redis.set(
|
|
||||||
f"fetcher:access_token:{self.client_id}",
|
f"fetcher:access_token:{self.client_id}",
|
||||||
self.access_token,
|
self.access_token,
|
||||||
ex=token_data["expires_in"],
|
ex=token_data["expires_in"],
|
||||||
)
|
)
|
||||||
redis.set(
|
await redis.set(
|
||||||
f"fetcher:refresh_token:{self.client_id}",
|
f"fetcher:refresh_token:{self.client_id}",
|
||||||
self.refresh_token,
|
self.refresh_token,
|
||||||
)
|
)
|
||||||
@@ -87,13 +86,12 @@ class BaseFetcher:
|
|||||||
self.refresh_token = token_data.get("refresh_token", "")
|
self.refresh_token = token_data.get("refresh_token", "")
|
||||||
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
if redis:
|
await redis.set(
|
||||||
redis.set(
|
|
||||||
f"fetcher:access_token:{self.client_id}",
|
f"fetcher:access_token:{self.client_id}",
|
||||||
self.access_token,
|
self.access_token,
|
||||||
ex=token_data["expires_in"],
|
ex=token_data["expires_in"],
|
||||||
)
|
)
|
||||||
redis.set(
|
await redis.set(
|
||||||
f"fetcher:refresh_token:{self.client_id}",
|
f"fetcher:refresh_token:{self.client_id}",
|
||||||
self.refresh_token,
|
self.refresh_token,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from ._base import BaseFetcher
|
|||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import redis
|
import redis.asyncio as redis
|
||||||
|
|
||||||
|
|
||||||
class OsuDotDirectFetcher(BaseFetcher):
|
class OsuDotDirectFetcher(BaseFetcher):
|
||||||
@@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher):
|
|||||||
self, redis: redis.Redis, beatmap_id: int
|
self, redis: redis.Redis, beatmap_id: int
|
||||||
) -> str:
|
) -> str:
|
||||||
if redis.exists(f"beatmap:{beatmap_id}:raw"):
|
if redis.exists(f"beatmap:{beatmap_id}:raw"):
|
||||||
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
|
return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
|
||||||
raw = await self.get_beatmap_raw(beatmap_id)
|
raw = await self.get_beatmap_raw(beatmap_id)
|
||||||
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
|
await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
|
||||||
return raw
|
return raw
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from .api_router import router
|
|||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, HTTPException, Query
|
||||||
from httpx import HTTPError, HTTPStatusError
|
from httpx import HTTPError, HTTPStatusError
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis import Redis
|
from redis.asyncio import Redis
|
||||||
import rosu_pp_py as rosu
|
import rosu_pp_py as rosu
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -127,8 +127,8 @@ async def get_beatmap_attributes(
|
|||||||
f"beatmap:{beatmap}:{ruleset}:"
|
f"beatmap:{beatmap}:{ruleset}:"
|
||||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||||
)
|
)
|
||||||
if redis.exists(key):
|
if await redis.exists(key):
|
||||||
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
|
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
|
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
|
||||||
@@ -138,7 +138,7 @@ async def get_beatmap_attributes(
|
|||||||
)
|
)
|
||||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
redis.set(key, attr.model_dump_json())
|
await redis.set(key, attr.model_dump_json())
|
||||||
return attr
|
return attr
|
||||||
except HTTPStatusError:
|
except HTTPStatusError:
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from app.models.room import Room
|
|||||||
|
|
||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, Query
|
||||||
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -19,17 +20,14 @@ async def get_all_rooms(
|
|||||||
status: str = Query(None),
|
status: str = Query(None),
|
||||||
category: str = Query(None),
|
category: str = Query(None),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
):
|
):
|
||||||
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
|
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
|
||||||
redis = get_redis()
|
|
||||||
roomsList: list[Room] = []
|
roomsList: list[Room] = []
|
||||||
if redis:
|
|
||||||
for room_index in all_room_ids:
|
for room_index in all_room_ids:
|
||||||
dumped_room = redis.get(str(room_index.id))
|
dumped_room = await redis.get(str(room_index.id))
|
||||||
if dumped_room:
|
if dumped_room:
|
||||||
actual_room = Room.model_validate_json(str(dumped_room))
|
actual_room = Room.model_validate_json(str(dumped_room))
|
||||||
if actual_room.status == status and actual_room.category == category:
|
if actual_room.status == status and actual_room.category == category:
|
||||||
roomsList.append(actual_room)
|
roomsList.append(actual_room)
|
||||||
return roomsList
|
return roomsList
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=500, detail="Redis Error")
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from .api_router import router
|
|||||||
|
|
||||||
from fastapi import Depends, Form, HTTPException, Query
|
from fastapi import Depends, Form, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|||||||
5
main.py
5
main.py
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies.database import create_tables, engine
|
from app.dependencies.database import create_tables, engine, redis_client
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
||||||
|
|
||||||
@@ -15,10 +15,11 @@ from fastapi import FastAPI
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# on startup
|
# on startup
|
||||||
await create_tables()
|
await create_tables()
|
||||||
get_fetcher() # 初始化 fetcher
|
await get_fetcher() # 初始化 fetcher
|
||||||
# on shutdown
|
# on shutdown
|
||||||
yield
|
yield
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
||||||
|
|||||||
Reference in New Issue
Block a user