From c5fc6afc189fbe665801412c1cff9cc7a308ccd6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 14:38:10 +0000 Subject: [PATCH] feat(redis): use asyncio --- app/database/score.py | 2 +- app/dependencies/database.py | 11 ++-------- app/dependencies/fetcher.py | 23 ++++++++++----------- app/fetcher/_base.py | 38 +++++++++++++++++------------------ app/fetcher/osu_dot_direct.py | 6 +++--- app/router/beatmap.py | 8 ++++---- app/router/room.py | 22 +++++++++----------- app/router/score.py | 2 +- main.py | 5 +++-- 9 files changed, 53 insertions(+), 64 deletions(-) diff --git a/app/database/score.py b/app/database/score.py index c5f1a38..642eac1 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -39,7 +39,7 @@ from .relationship import ( ) from .score_token import ScoreToken -from redis import Redis +from redis.asyncio import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import aliased diff --git a/app/dependencies/database.py b/app/dependencies/database.py index fe09139..77b15c3 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -5,15 +5,11 @@ import json from app.config import settings from pydantic import BaseModel +import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession -try: - import redis -except ImportError: - redis = None - def json_serializer(value): if isinstance(value, BaseModel | SQLModel): @@ -25,10 +21,7 @@ def json_serializer(value): engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) # Redis 连接 -if redis: - redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) -else: - redis_client = None +redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) # 数据库依赖 diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index d3c216a..806eb87 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -8,7 +8,7 @@ from app.log import logger fetcher: Fetcher | None = None -def get_fetcher() -> Fetcher: +async def get_fetcher() -> Fetcher: global fetcher if fetcher is None: fetcher = Fetcher( @@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher: settings.FETCHER_CALLBACK_URL, ) redis = get_redis() - if redis: - access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") - if access_token: - fetcher.access_token = str(access_token) - refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") - if refresh_token: - fetcher.refresh_token = str(refresh_token) - if not fetcher.access_token or not fetcher.refresh_token: - logger.opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") + if access_token: + fetcher.access_token = str(access_token) + refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") + if refresh_token: + fetcher.refresh_token = str(refresh_token) + if not fetcher.access_token or not fetcher.refresh_token: + logger.opt(colors=True).info( + f"Login to initialize fetcher: {fetcher.authorize_url}" + ) return fetcher diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 08e3508..2717a35 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -59,16 +59,15 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) async def refresh_access_token(self) -> None: async with AsyncClient() as client: @@ -87,13 +86,12 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index 08b8dfc..cb3897f 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -4,7 +4,7 @@ from ._base import BaseFetcher from httpx import AsyncClient from loguru import logger -import redis +import redis.asyncio as redis class OsuDotDirectFetcher(BaseFetcher): @@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher): self, redis: redis.Redis, beatmap_id: int ) -> str: if redis.exists(f"beatmap:{beatmap_id}:raw"): - return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) - redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) return raw diff --git a/app/router/beatmap.py b/app/router/beatmap.py index df5f20d..0a25562 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -22,7 +22,7 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis import rosu_pp_py as rosu from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -127,8 +127,8 @@ async def get_beatmap_attributes( f"beatmap:{beatmap}:{ruleset}:" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" ) - if redis.exists(key): - return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] + if await redis.exists(key): + return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) @@ -138,7 +138,7 @@ async def get_beatmap_attributes( ) except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] raise HTTPException(status_code=400, detail=str(e)) - redis.set(key, attr.model_dump_json()) + await redis.set(key, attr.model_dump_json()) return attr except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") diff --git a/app/router/room.py b/app/router/room.py index ed540fc..3a65617 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -6,7 +6,8 @@ from app.models.room import Room from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Query +from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,17 +20,14 @@ async def get_all_rooms( status: str = Query(None), category: str = Query(None), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), ): all_room_ids = (await db.exec(select(RoomIndex).where(True))).all() - redis = get_redis() roomsList: list[Room] = [] - if redis: - for room_index in all_room_ids: - dumped_room = redis.get(str(room_index.id)) - if dumped_room: - actual_room = Room.model_validate_json(str(dumped_room)) - if actual_room.status == status and actual_room.category == category: - roomsList.append(actual_room) - return roomsList - else: - raise HTTPException(status_code=500, detail="Redis Error") + for room_index in all_room_ids: + dumped_room = await redis.get(str(room_index.id)) + if dumped_room: + actual_room = Room.model_validate_json(str(dumped_room)) + if actual_room.status == status and actual_room.category == category: + roomsList.append(actual_room) + return roomsList diff --git a/app/router/score.py b/app/router/score.py index 6c6a475..2f1303e 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -18,7 +18,7 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession diff --git a/main.py b/main.py index 92d4402..f5d20c1 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.dependencies.database import create_tables, engine +from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher from app.router import api_router, auth_router, fetcher_router, signalr_router @@ -15,10 +15,11 @@ from fastapi import FastAPI async def lifespan(app: FastAPI): # on startup await create_tables() - get_fetcher() # 初始化 fetcher + await get_fetcher() # 初始化 fetcher # on shutdown yield await engine.dispose() + await redis_client.aclose() app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)