feat(redis): use asyncio

This commit is contained in:
MingxuanGame
2025-07-31 14:38:10 +00:00
parent 1635641654
commit c5fc6afc18
9 changed files with 53 additions and 64 deletions

View File

@@ -39,7 +39,7 @@ from .relationship import (
)
from .score_token import ScoreToken
from redis import Redis
from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased

View File

@@ -5,15 +5,11 @@ import json
from app.config import settings
from pydantic import BaseModel
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
except ImportError:
redis = None
def json_serializer(value):
if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接
if redis:
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
else:
redis_client = None
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
# 数据库依赖

View File

@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None
def get_fetcher() -> Fetcher:
async def get_fetcher() -> Fetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL,
)
redis = get_redis()
if redis:
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
if access_token:
fetcher.access_token = str(access_token)
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
if access_token:
fetcher.access_token = str(access_token)
refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
return fetcher

View File

@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
if redis:
redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
async def refresh_access_token(self) -> None:
async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
if redis:
redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)

View File

@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient
from loguru import logger
import redis
import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher):
@@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher):
self, redis: redis.Redis, beatmap_id: int
) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw

View File

@@ -22,7 +22,7 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis import Redis
from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -127,8 +127,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
if redis.exists(key):
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -138,7 +138,7 @@ async def get_beatmap_attributes(
)
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e))
redis.set(key, attr.model_dump_json())
await redis.set(key, attr.model_dump_json())
return attr
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")

View File

@@ -6,7 +6,8 @@ from app.models.room import Room
from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, Query
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,17 +20,14 @@ async def get_all_rooms(
status: str = Query(None),
category: str = Query(None),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
redis = get_redis()
roomsList: list[Room] = []
if redis:
for room_index in all_room_ids:
dumped_room = redis.get(str(room_index.id))
if dumped_room:
actual_room = Room.model_validate_json(str(dumped_room))
if actual_room.status == status and actual_room.category == category:
roomsList.append(actual_room)
return roomsList
else:
raise HTTPException(status_code=500, detail="Redis Error")
for room_index in all_room_ids:
dumped_room = await redis.get(str(room_index.id))
if dumped_room:
actual_room = Room.model_validate_json(str(dumped_room))
if actual_room.status == status and actual_room.category == category:
roomsList.append(actual_room)
return roomsList

View File

@@ -18,7 +18,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
from redis import Redis
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
from app.dependencies.database import create_tables, engine
from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.fetcher import get_fetcher
from app.router import api_router, auth_router, fetcher_router, signalr_router
@@ -15,10 +15,11 @@ from fastapi import FastAPI
async def lifespan(app: FastAPI):
# on startup
await create_tables()
get_fetcher() # 初始化 fetcher
await get_fetcher() # 初始化 fetcher
# on shutdown
yield
await engine.dispose()
await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)