feat(redis): use asyncio

This commit is contained in:
MingxuanGame
2025-07-31 14:38:10 +00:00
parent 1635641654
commit c5fc6afc18
9 changed files with 53 additions and 64 deletions

View File

@@ -39,7 +39,7 @@ from .relationship import (
) )
from .score_token import ScoreToken from .score_token import ScoreToken
from redis import Redis from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased

View File

@@ -5,15 +5,11 @@ import json
from app.config import settings from app.config import settings
from pydantic import BaseModel from pydantic import BaseModel
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
except ImportError:
redis = None
def json_serializer(value): def json_serializer(value):
if isinstance(value, BaseModel | SQLModel): if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接 # Redis 连接
if redis: redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
else:
redis_client = None
# 数据库依赖 # 数据库依赖

View File

@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None fetcher: Fetcher | None = None
def get_fetcher() -> Fetcher: async def get_fetcher() -> Fetcher:
global fetcher global fetcher
if fetcher is None: if fetcher is None:
fetcher = Fetcher( fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL, settings.FETCHER_CALLBACK_URL,
) )
redis = get_redis() redis = get_redis()
if redis: access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") if access_token:
if access_token: fetcher.access_token = str(access_token)
fetcher.access_token = str(access_token) refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") if refresh_token:
if refresh_token: fetcher.refresh_token = str(refresh_token)
fetcher.refresh_token = str(refresh_token) if not fetcher.access_token or not fetcher.refresh_token:
if not fetcher.access_token or not fetcher.refresh_token: logger.opt(colors=True).info(
logger.opt(colors=True).info( f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>" )
)
return fetcher return fetcher

View File

@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "") self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"] self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis() redis = get_redis()
if redis: await redis.set(
redis.set( f"fetcher:access_token:{self.client_id}",
f"fetcher:access_token:{self.client_id}", self.access_token,
self.access_token, ex=token_data["expires_in"],
ex=token_data["expires_in"], )
) await redis.set(
redis.set( f"fetcher:refresh_token:{self.client_id}",
f"fetcher:refresh_token:{self.client_id}", self.refresh_token,
self.refresh_token, )
)
async def refresh_access_token(self) -> None: async def refresh_access_token(self) -> None:
async with AsyncClient() as client: async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "") self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"] self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis() redis = get_redis()
if redis: await redis.set(
redis.set( f"fetcher:access_token:{self.client_id}",
f"fetcher:access_token:{self.client_id}", self.access_token,
self.access_token, ex=token_data["expires_in"],
ex=token_data["expires_in"], )
) await redis.set(
redis.set( f"fetcher:refresh_token:{self.client_id}",
f"fetcher:refresh_token:{self.client_id}", self.refresh_token,
self.refresh_token, )
)

View File

@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient from httpx import AsyncClient
from loguru import logger from loguru import logger
import redis import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher): class OsuDotDirectFetcher(BaseFetcher):
@@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher):
self, redis: redis.Redis, beatmap_id: int self, redis: redis.Redis, beatmap_id: int
) -> str: ) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"): if redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id) raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw return raw

View File

@@ -22,7 +22,7 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel from pydantic import BaseModel
from redis import Redis from redis.asyncio import Redis
import rosu_pp_py as rosu import rosu_pp_py as rosu
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -127,8 +127,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:" f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
) )
if redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try: try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -138,7 +138,7 @@ async def get_beatmap_attributes(
) )
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
redis.set(key, attr.model_dump_json()) await redis.set(key, attr.model_dump_json())
return attr return attr
except HTTPStatusError: except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")

View File

@@ -6,7 +6,8 @@ from app.models.room import Room
from .api_router import router from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, Query
from redis.asyncio import Redis
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,17 +20,14 @@ async def get_all_rooms(
status: str = Query(None), status: str = Query(None),
category: str = Query(None), category: str = Query(None),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
): ):
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all() all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
redis = get_redis()
roomsList: list[Room] = [] roomsList: list[Room] = []
if redis: for room_index in all_room_ids:
for room_index in all_room_ids: dumped_room = await redis.get(str(room_index.id))
dumped_room = redis.get(str(room_index.id)) if dumped_room:
if dumped_room: actual_room = Room.model_validate_json(str(dumped_room))
actual_room = Room.model_validate_json(str(dumped_room)) if actual_room.status == status and actual_room.category == category:
if actual_room.status == status and actual_room.category == category: roomsList.append(actual_room)
roomsList.append(actual_room) return roomsList
return roomsList
else:
raise HTTPException(status_code=500, detail="Redis Error")

View File

@@ -18,7 +18,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from redis import Redis from redis.asyncio import Redis
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from app.config import settings from app.config import settings
from app.dependencies.database import create_tables, engine from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.router import api_router, auth_router, fetcher_router, signalr_router from app.router import api_router, auth_router, fetcher_router, signalr_router
@@ -15,10 +15,11 @@ from fastapi import FastAPI
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# on startup # on startup
await create_tables() await create_tables()
get_fetcher() # 初始化 fetcher await get_fetcher() # 初始化 fetcher
# on shutdown # on shutdown
yield yield
await engine.dispose() await engine.dispose()
await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)