From ad51514fb183ce3e10bc222d60ecdb076a356d80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com> Date: Thu, 21 Aug 2025 23:48:58 +0800 Subject: [PATCH] fix Pydantic serializer warnings --- app/database/lazer_user.py | 13 +++ app/database/score.py | 52 +++++++++ app/database/statistics.py | 13 +++ app/service/ranking_cache_service.py | 28 ++++- test_ranking_serialization.py | 165 +++++++++++++++++++++++++++ 5 files changed, 265 insertions(+), 6 deletions(-) create mode 100644 test_ranking_serialization.py diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index f3523f7..515f769 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -17,6 +17,7 @@ from .statistics import UserStatistics, UserStatisticsResp from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp +from pydantic import field_validator from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( JSON, @@ -135,6 +136,18 @@ class UserBase(UTCBaseModel, SQLModel): is_qat: bool = False is_bng: bool = False + @field_validator('playmode', mode='before') + @classmethod + def validate_playmode(cls, v): + """将字符串转换为 GameMode 枚举""" + if isinstance(v, str): + try: + return GameMode(v) + except ValueError: + # 如果转换失败,返回默认值 + return GameMode.OSU + return v + class User(AsyncAttrs, UserBase, table=True): __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] diff --git a/app/database/score.py b/app/database/score.py index a92c6ad..ddfef6d 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -100,6 +100,26 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): sa_column=Column(JSON), default_factory=dict ) + @field_validator('maximum_statistics', mode='before') + @classmethod + def validate_maximum_statistics(cls, v): + """处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举""" + if isinstance(v, dict): + converted = {} + for key, value in v.items(): + if isinstance(key, str): + try: + # 尝试将字符串转换为 HitResult 枚举 + enum_key = HitResult(key) + converted[enum_key] = value + except ValueError: + # 如果转换失败,跳过这个键值对 + continue + else: + converted[key] = value + return converted + return v + # optional # TODO: current_user_attributes @@ -131,6 +151,18 @@ class Score(ScoreBase, table=True): gamemode: GameMode = Field(index=True) pinned_order: int = Field(default=0, exclude=True) + @field_validator('gamemode', mode='before') + @classmethod + def validate_gamemode(cls, v): + """将字符串转换为 GameMode 枚举""" + if isinstance(v, str): + try: + return GameMode(v) + except ValueError: + # 如果转换失败,返回默认值 + return GameMode.OSU + return v + # optional beatmap: Beatmap = Relationship() user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @@ -185,6 +217,26 @@ class ScoreResp(ScoreBase): return bool(v) return v + @field_validator('statistics', 'maximum_statistics', mode='before') + @classmethod + def validate_statistics_fields(cls, v): + """处理统计字段中的字符串键,转换为 HitResult 枚举""" + if isinstance(v, dict): + converted = {} + for key, value in v.items(): + if isinstance(key, str): + try: + # 尝试将字符串转换为 HitResult 枚举 + enum_key = HitResult(key) + converted[enum_key] = value + except ValueError: + # 如果转换失败,跳过这个键值对 + continue + else: + converted[key] = value + return converted + return v + @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) diff --git a/app/database/statistics.py b/app/database/statistics.py index be1f762..04b037d 100644 --- a/app/database/statistics.py +++ b/app/database/statistics.py @@ -6,6 +6,7 @@ from app.models.score import GameMode from .rank_history import RankHistory +from pydantic import field_validator from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( BigInteger, @@ -43,6 +44,18 @@ class UserStatisticsBase(SQLModel): replays_watched_by_others: int = Field(default=0) is_ranked: bool = Field(default=True) + @field_validator('mode', mode='before') + @classmethod + def validate_mode(cls, v): + """将字符串转换为 GameMode 枚举""" + if isinstance(v, str): + try: + return GameMode(v) + except ValueError: + # 如果转换失败,返回默认值 + return GameMode.OSU + return v + class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True): __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py index 0cd75ac..7eff17d 100644 --- a/app/service/ranking_cache_service.py +++ b/app/service/ranking_cache_service.py @@ -22,6 +22,20 @@ if TYPE_CHECKING: pass +class DateTimeEncoder(json.JSONEncoder): + """自定义 JSON 编码器,支持 datetime 序列化""" + + def default(self, obj): + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + +def safe_json_dumps(data) -> str: + """安全的 JSON 序列化,支持 datetime 对象""" + return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")) + + class RankingCacheService: """用户排行榜缓存服务""" @@ -95,7 +109,7 @@ class RankingCacheService: ttl = settings.ranking_cache_expire_minutes * 60 await self.redis.set( cache_key, - json.dumps(ranking_data, separators=(",", ":")), + safe_json_dumps(ranking_data), ex=ttl ) logger.debug(f"Cached ranking data for {cache_key}") @@ -136,7 +150,7 @@ class RankingCacheService: ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间 await self.redis.set( cache_key, - json.dumps(stats, separators=(",", ":")), + safe_json_dumps(stats), ex=ttl ) logger.debug(f"Cached stats for {cache_key}") @@ -174,7 +188,7 @@ class RankingCacheService: ttl = settings.ranking_cache_expire_minutes * 60 await self.redis.set( cache_key, - json.dumps(ranking_data, separators=(",", ":")), + safe_json_dumps(ranking_data), ex=ttl ) logger.debug(f"Cached country ranking data for {cache_key}") @@ -207,7 +221,7 @@ class RankingCacheService: ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间 await self.redis.set( cache_key, - json.dumps(stats, separators=(",", ":")), + safe_json_dumps(stats), ex=ttl ) logger.debug(f"Cached country stats for {cache_key}") @@ -283,13 +297,15 @@ class RankingCacheService: if not statistics_data: break # 没有更多数据 - # 转换为响应格式 + # 转换为响应格式并确保正确序列化 ranking_data = [] for statistics in statistics_data: user_stats_resp = await UserStatisticsResp.from_db( statistics, session, None, include ) - ranking_data.append(user_stats_resp.model_dump()) + # 将 UserStatisticsResp 转换为字典,处理所有序列化问题 + user_dict = json.loads(user_stats_resp.model_dump_json()) + ranking_data.append(user_dict) # 缓存这一页的数据 await self.cache_ranking( diff --git a/test_ranking_serialization.py b/test_ranking_serialization.py new file mode 100644 index 0000000..91e767a --- /dev/null +++ b/test_ranking_serialization.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +"""测试排行榜缓存序列化修复""" + +import asyncio +import warnings +from datetime import datetime, UTC +from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps + + +def test_datetime_serialization(): + """测试 datetime 序列化""" + print("🧪 测试 datetime 序列化...") + + test_data = { + "id": 1, + "username": "test_user", + "last_updated": datetime.now(UTC), + "join_date": datetime(2020, 1, 1, tzinfo=UTC), + "stats": { + "pp": 1000.0, + "accuracy": 95.5, + "last_played": datetime.now(UTC) + } + } + + try: + # 测试自定义编码器 + json_result = safe_json_dumps(test_data) + print("✅ datetime 序列化成功") + print(f" 序列化结果长度: {len(json_result)}") + + # 验证可以重新解析 + import json + parsed = json.loads(json_result) + assert "last_updated" in parsed + assert isinstance(parsed["last_updated"], str) + print("✅ 序列化的 JSON 可以正确解析") + + except Exception as e: + print(f"❌ datetime 序列化测试失败: {e}") + import traceback + traceback.print_exc() + + +def test_boolean_serialization(): + """测试布尔值序列化""" + print("\n🧪 测试布尔值序列化...") + + test_data = { + "user": { + "is_active": 1, # 数据库中的整数布尔值 + "is_supporter": 0, # 数据库中的整数布尔值 + "has_profile": True, # 正常布尔值 + }, + "stats": { + "is_ranked": 1, # 数据库中的整数布尔值 + "verified": False, # 正常布尔值 + } + } + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + json_result = safe_json_dumps(test_data) + + # 检查是否有 Pydantic 序列化警告 + pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)] + if pydantic_warnings: + print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告") + for warning in pydantic_warnings: + print(f" {warning.message}") + else: + print("✅ 布尔值序列化无警告") + + # 验证序列化结果 + import json + parsed = json.loads(json_result) + print(f"✅ 布尔值序列化成功,结果: {parsed}") + + except Exception as e: + print(f"❌ 布尔值序列化测试失败: {e}") + import traceback + traceback.print_exc() + + +def test_complex_ranking_data(): + """测试复杂的排行榜数据序列化""" + print("\n🧪 测试复杂排行榜数据序列化...") + + # 模拟排行榜数据结构 + ranking_data = [ + { + "id": 1, + "user": { + "id": 1, + "username": "player1", + "country_code": "US", + "is_active": 1, # 整数布尔值 + "is_supporter": 0, # 整数布尔值 + "join_date": datetime(2020, 1, 1, tzinfo=UTC), + "last_visit": datetime.now(UTC), + }, + "statistics": { + "pp": 8000.0, + "accuracy": 98.5, + "play_count": 5000, + "is_ranked": 1, # 整数布尔值 + "last_updated": datetime.now(UTC), + } + }, + { + "id": 2, + "user": { + "id": 2, + "username": "player2", + "country_code": "JP", + "is_active": 1, + "is_supporter": 1, + "join_date": datetime(2019, 6, 15, tzinfo=UTC), + "last_visit": datetime.now(UTC), + }, + "statistics": { + "pp": 7500.0, + "accuracy": 97.8, + "play_count": 4500, + "is_ranked": 1, + "last_updated": datetime.now(UTC), + } + } + ] + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + json_result = safe_json_dumps(ranking_data) + + pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)] + if pydantic_warnings: + print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告") + for warning in pydantic_warnings: + print(f" {warning.message}") + else: + print("✅ 复杂排行榜数据序列化无警告") + + # 验证序列化结果 + import json + parsed = json.loads(json_result) + assert len(parsed) == 2 + assert parsed[0]["user"]["username"] == "player1" + print(f"✅ 复杂排行榜数据序列化成功,包含 {len(parsed)} 个条目") + + except Exception as e: + print(f"❌ 复杂排行榜数据序列化测试失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + print("🚀 开始排行榜缓存序列化测试\n") + + test_datetime_serialization() + test_boolean_serialization() + test_complex_ranking_data() + + print("\n🎉 排行榜缓存序列化测试完成!")