diff --git a/app/database/score.py b/app/database/score.py
index c5f1a38..642eac1 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -39,7 +39,7 @@ from .relationship import (
)
from .score_token import ScoreToken
-from redis import Redis
+from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
diff --git a/app/dependencies/database.py b/app/dependencies/database.py
index fe09139..77b15c3 100644
--- a/app/dependencies/database.py
+++ b/app/dependencies/database.py
@@ -5,15 +5,11 @@ import json
from app.config import settings
from pydantic import BaseModel
+import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
-try:
- import redis
-except ImportError:
- redis = None
-
def json_serializer(value):
if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接
-if redis:
- redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
-else:
- redis_client = None
+redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
# 数据库依赖
diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py
index d3c216a..806eb87 100644
--- a/app/dependencies/fetcher.py
+++ b/app/dependencies/fetcher.py
@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None
-def get_fetcher() -> Fetcher:
+async def get_fetcher() -> Fetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL,
)
redis = get_redis()
- if redis:
- access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
- if access_token:
- fetcher.access_token = str(access_token)
- refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
- if refresh_token:
- fetcher.refresh_token = str(refresh_token)
- if not fetcher.access_token or not fetcher.refresh_token:
- logger.opt(colors=True).info(
- f"Login to initialize fetcher: {fetcher.authorize_url}"
- )
+ access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
+ if access_token:
+ fetcher.access_token = str(access_token)
+ refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
+ if refresh_token:
+ fetcher.refresh_token = str(refresh_token)
+ if not fetcher.access_token or not fetcher.refresh_token:
+ logger.opt(colors=True).info(
+ f"Login to initialize fetcher: {fetcher.authorize_url}"
+ )
return fetcher
diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py
index 08e3508..2717a35 100644
--- a/app/fetcher/_base.py
+++ b/app/fetcher/_base.py
@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
- if redis:
- redis.set(
- f"fetcher:access_token:{self.client_id}",
- self.access_token,
- ex=token_data["expires_in"],
- )
- redis.set(
- f"fetcher:refresh_token:{self.client_id}",
- self.refresh_token,
- )
+ await redis.set(
+ f"fetcher:access_token:{self.client_id}",
+ self.access_token,
+ ex=token_data["expires_in"],
+ )
+ await redis.set(
+ f"fetcher:refresh_token:{self.client_id}",
+ self.refresh_token,
+ )
async def refresh_access_token(self) -> None:
async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
- if redis:
- redis.set(
- f"fetcher:access_token:{self.client_id}",
- self.access_token,
- ex=token_data["expires_in"],
- )
- redis.set(
- f"fetcher:refresh_token:{self.client_id}",
- self.refresh_token,
- )
+ await redis.set(
+ f"fetcher:access_token:{self.client_id}",
+ self.access_token,
+ ex=token_data["expires_in"],
+ )
+ await redis.set(
+ f"fetcher:refresh_token:{self.client_id}",
+ self.refresh_token,
+ )
diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py
index 08b8dfc..cb3897f 100644
--- a/app/fetcher/osu_dot_direct.py
+++ b/app/fetcher/osu_dot_direct.py
@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient
from loguru import logger
-import redis
+import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher):
@@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher):
self, redis: redis.Redis, beatmap_id: int
) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"):
- return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
+ return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
- redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
+ await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw
diff --git a/app/router/beatmap.py b/app/router/beatmap.py
index df5f20d..0a25562 100644
--- a/app/router/beatmap.py
+++ b/app/router/beatmap.py
@@ -22,7 +22,7 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
-from redis import Redis
+from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -127,8 +127,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
- if redis.exists(key):
- return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
+ if await redis.exists(key):
+ return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -138,7 +138,7 @@ async def get_beatmap_attributes(
)
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e))
- redis.set(key, attr.model_dump_json())
+ await redis.set(key, attr.model_dump_json())
return attr
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
diff --git a/app/router/room.py b/app/router/room.py
index ed540fc..3a65617 100644
--- a/app/router/room.py
+++ b/app/router/room.py
@@ -6,7 +6,8 @@ from app.models.room import Room
from .api_router import router
-from fastapi import Depends, HTTPException, Query
+from fastapi import Depends, Query
+from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,17 +20,14 @@ async def get_all_rooms(
status: str = Query(None),
category: str = Query(None),
db: AsyncSession = Depends(get_db),
+ redis: Redis = Depends(get_redis),
):
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
- redis = get_redis()
roomsList: list[Room] = []
- if redis:
- for room_index in all_room_ids:
- dumped_room = redis.get(str(room_index.id))
- if dumped_room:
- actual_room = Room.model_validate_json(str(dumped_room))
- if actual_room.status == status and actual_room.category == category:
- roomsList.append(actual_room)
- return roomsList
- else:
- raise HTTPException(status_code=500, detail="Redis Error")
+ for room_index in all_room_ids:
+ dumped_room = await redis.get(str(room_index.id))
+ if dumped_room:
+ actual_room = Room.model_validate_json(str(dumped_room))
+ if actual_room.status == status and actual_room.category == category:
+ roomsList.append(actual_room)
+ return roomsList
diff --git a/app/router/score.py b/app/router/score.py
index 6c6a475..2f1303e 100644
--- a/app/router/score.py
+++ b/app/router/score.py
@@ -18,7 +18,7 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
-from redis import Redis
+from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
diff --git a/main.py b/main.py
index 92d4402..f5d20c1 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
-from app.dependencies.database import create_tables, engine
+from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.fetcher import get_fetcher
from app.router import api_router, auth_router, fetcher_router, signalr_router
@@ -15,10 +15,11 @@ from fastapi import FastAPI
async def lifespan(app: FastAPI):
# on startup
await create_tables()
- get_fetcher() # 初始化 fetcher
+ await get_fetcher() # 初始化 fetcher
# on shutdown
yield
await engine.dispose()
+ await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)