fix Pydantic serializer warnings

This commit is contained in:
咕谷酱
2025-08-21 23:48:58 +08:00
parent 822d7c6377
commit ad51514fb1
5 changed files with 265 additions and 6 deletions

View File

@@ -17,6 +17,7 @@ from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from pydantic import field_validator
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
JSON,
@@ -135,6 +136,18 @@ class UserBase(UTCBaseModel, SQLModel):
is_qat: bool = False
is_bng: bool = False
@field_validator('playmode', mode='before')
@classmethod
def validate_playmode(cls, v):
"""将字符串转换为 GameMode 枚举"""
if isinstance(v, str):
try:
return GameMode(v)
except ValueError:
# 如果转换失败,返回默认值
return GameMode.OSU
return v
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]

View File

@@ -100,6 +100,26 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
sa_column=Column(JSON), default_factory=dict
)
@field_validator('maximum_statistics', mode='before')
@classmethod
def validate_maximum_statistics(cls, v):
"""处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
if isinstance(v, dict):
converted = {}
for key, value in v.items():
if isinstance(key, str):
try:
# 尝试将字符串转换为 HitResult 枚举
enum_key = HitResult(key)
converted[enum_key] = value
except ValueError:
# 如果转换失败,跳过这个键值对
continue
else:
converted[key] = value
return converted
return v
# optional
# TODO: current_user_attributes
@@ -131,6 +151,18 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
@field_validator('gamemode', mode='before')
@classmethod
def validate_gamemode(cls, v):
"""将字符串转换为 GameMode 枚举"""
if isinstance(v, str):
try:
return GameMode(v)
except ValueError:
# 如果转换失败,返回默认值
return GameMode.OSU
return v
# optional
beatmap: Beatmap = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@@ -185,6 +217,26 @@ class ScoreResp(ScoreBase):
return bool(v)
return v
@field_validator('statistics', 'maximum_statistics', mode='before')
@classmethod
def validate_statistics_fields(cls, v):
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
if isinstance(v, dict):
converted = {}
for key, value in v.items():
if isinstance(key, str):
try:
# 尝试将字符串转换为 HitResult 枚举
enum_key = HitResult(key)
converted[enum_key] = value
except ValueError:
# 如果转换失败,跳过这个键值对
continue
else:
converted[key] = value
return converted
return v
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
s = cls.model_validate(score.model_dump())

View File

@@ -6,6 +6,7 @@ from app.models.score import GameMode
from .rank_history import RankHistory
from pydantic import field_validator
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -43,6 +44,18 @@ class UserStatisticsBase(SQLModel):
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=True)
@field_validator('mode', mode='before')
@classmethod
def validate_mode(cls, v):
"""将字符串转换为 GameMode 枚举"""
if isinstance(v, str):
try:
return GameMode(v)
except ValueError:
# 如果转换失败,返回默认值
return GameMode.OSU
return v
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]

View File

@@ -22,6 +22,20 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
class RankingCacheService:
"""用户排行榜缓存服务"""
@@ -95,7 +109,7 @@ class RankingCacheService:
ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set(
cache_key,
json.dumps(ranking_data, separators=(",", ":")),
safe_json_dumps(ranking_data),
ex=ttl
)
logger.debug(f"Cached ranking data for {cache_key}")
@@ -136,7 +150,7 @@ class RankingCacheService:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set(
cache_key,
json.dumps(stats, separators=(",", ":")),
safe_json_dumps(stats),
ex=ttl
)
logger.debug(f"Cached stats for {cache_key}")
@@ -174,7 +188,7 @@ class RankingCacheService:
ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set(
cache_key,
json.dumps(ranking_data, separators=(",", ":")),
safe_json_dumps(ranking_data),
ex=ttl
)
logger.debug(f"Cached country ranking data for {cache_key}")
@@ -207,7 +221,7 @@ class RankingCacheService:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set(
cache_key,
json.dumps(stats, separators=(",", ":")),
safe_json_dumps(stats),
ex=ttl
)
logger.debug(f"Cached country stats for {cache_key}")
@@ -283,13 +297,15 @@ class RankingCacheService:
if not statistics_data:
break # 没有更多数据
# 转换为响应格式
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
ranking_data.append(user_stats_resp.model_dump())
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
user_dict = json.loads(user_stats_resp.model_dump_json())
ranking_data.append(user_dict)
# 缓存这一页的数据
await self.cache_ranking(