diff --git a/app/database/score.py b/app/database/score.py index 88e2cfa..9710138 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -45,7 +45,7 @@ from .relationship import ( ) from .score_token import ScoreToken -from pydantic import field_validator +from pydantic import field_validator, field_serializer from redis.asyncio import Redis from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs @@ -120,6 +120,29 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): return converted return v + @field_serializer("maximum_statistics", when_used="json") + def serialize_maximum_statistics(self, v): + """序列化 maximum_statistics 字段,确保枚举值正确转换为字符串""" + if isinstance(v, dict): + serialized = {} + for key, value in v.items(): + if hasattr(key, 'value'): + # 如果是枚举,使用其值 + serialized[key.value] = value + else: + # 否则直接使用键 + serialized[str(key)] = value + return serialized + return v + + @field_serializer("rank", when_used="json") + def serialize_rank(self, v): + """序列化等级,确保枚举值正确转换为字符串""" + if hasattr(v, 'value'): + return v.value + return str(v) + return v + # optional # TODO: current_user_attributes @@ -163,6 +186,13 @@ class Score(ScoreBase, table=True): return GameMode.OSU return v + @field_serializer("gamemode", when_used="json") + def serialize_gamemode(self, v): + """序列化游戏模式,确保枚举值正确转换为字符串""" + if hasattr(v, 'value'): + return v.value + return str(v) + # optional beatmap: Beatmap = Relationship() user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @@ -246,6 +276,28 @@ class ScoreResp(ScoreBase): return converted return v + @field_serializer("statistics", "maximum_statistics", when_used="json") + def serialize_statistics_fields(self, v): + """序列化统计字段,确保枚举值正确转换为字符串""" + if isinstance(v, dict): + serialized = {} + for key, value in v.items(): + if hasattr(key, 'value'): + # 如果是枚举,使用其值 + serialized[key.value] = value + else: + # 否则直接使用键 + serialized[str(key)] = value + return serialized + return v + + @field_serializer("gamemode", when_used="json") + def serialize_gamemode(self, v): + """序列化游戏模式,确保枚举值正确转换为字符串""" + if hasattr(v, 'value'): + return v.value + return str(v) + @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) diff --git a/app/models/score.py b/app/models/score.py index b319703..f75e827 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -7,7 +7,7 @@ from app.config import settings from .mods import API_MODS, APIMod -from pydantic import BaseModel, Field, ValidationInfo, field_validator +from pydantic import BaseModel, Field, ValidationInfo, field_validator, field_serializer if TYPE_CHECKING: import rosu_pp_py as rosu @@ -206,6 +206,28 @@ class SoloScoreSubmissionInfo(BaseModel): incompatible_mods.update(setting_mods["IncompatibleMods"]) return mods + @field_serializer("statistics", "maximum_statistics", when_used="json") + def serialize_statistics(self, v): + """序列化统计字段,确保枚举值正确转换为字符串""" + if isinstance(v, dict): + serialized = {} + for key, value in v.items(): + if hasattr(key, 'value'): + # 如果是枚举,使用其值 + serialized[key.value] = value + else: + # 否则直接使用键 + serialized[str(key)] = value + return serialized + return v + + @field_serializer("rank", when_used="json") + def serialize_rank(self, v): + """序列化等级,确保枚举值正确转换为字符串""" + if hasattr(v, 'value'): + return v.value + return str(v) + class LegacyReplaySoloScoreInfo(TypedDict): online_id: int diff --git a/app/router/v2/score.py b/app/router/v2/score.py index d3eb903..050fa3f 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -34,7 +34,7 @@ from app.database.score import ( process_score, process_user, ) -from app.dependencies.database import Database, get_redis, with_db +from app.dependencies.database import Database, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.storage import get_storage_service from app.dependencies.user import get_client_user, get_current_user @@ -75,8 +75,14 @@ READ_SCORE_TIMEOUT = 10 async def process_user_achievement(score_id: int): - async with with_db() as session: + from sqlmodel.ext.asyncio.session import AsyncSession + from app.dependencies.database import engine + + session = AsyncSession(engine) + try: await process_achievements(session, get_redis(), score_id) + finally: + await session.close() async def submit_score( @@ -184,20 +190,37 @@ async def submit_score( db.add(rank_event) await db.commit() - # 成绩提交后刷新用户缓存 - try: - user_cache_service = get_user_cache_service(redis) - if current_user.id is not None: - await user_cache_service.refresh_user_cache_on_score_submit( - db, current_user.id, score.gamemode - ) - except Exception as e: - logger.error(f"Failed to refresh user cache after score submit: {e}") - + # 成绩提交后刷新用户缓存 - 移至后台任务避免阻塞主流程 + if current_user.id is not None: + background_task.add_task( + _refresh_user_cache_background, + redis, + current_user.id, + score.gamemode + ) background_task.add_task(process_user_achievement, resp.id) return resp +async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameMode): + """后台任务:刷新用户缓存""" + try: + from sqlmodel.ext.asyncio.session import AsyncSession + from app.dependencies.database import engine + + user_cache_service = get_user_cache_service(redis) + # 创建独立的数据库会话 + session = AsyncSession(engine) + try: + await user_cache_service.refresh_user_cache_on_score_submit( + session, user_id, mode + ) + finally: + await session.close() + except Exception as e: + logger.error(f"Failed to refresh user cache after score submit: {e}") + + async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None: """ 预缓存beatmap文件以加速PP计算