fix Pydantic serializer warnings
This commit is contained in:
@@ -17,6 +17,7 @@ from .statistics import UserStatistics, UserStatisticsResp
|
||||
from .team import Team, TeamMember
|
||||
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -135,6 +136,18 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
is_qat: bool = False
|
||||
is_bng: bool = False
|
||||
|
||||
@field_validator('playmode', mode='before')
|
||||
@classmethod
|
||||
def validate_playmode(cls, v):
|
||||
"""将字符串转换为 GameMode 枚举"""
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return GameMode(v)
|
||||
except ValueError:
|
||||
# 如果转换失败,返回默认值
|
||||
return GameMode.OSU
|
||||
return v
|
||||
|
||||
|
||||
class User(AsyncAttrs, UserBase, table=True):
|
||||
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
@@ -100,6 +100,26 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
sa_column=Column(JSON), default_factory=dict
|
||||
)
|
||||
|
||||
@field_validator('maximum_statistics', mode='before')
|
||||
@classmethod
|
||||
def validate_maximum_statistics(cls, v):
|
||||
"""处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
|
||||
if isinstance(v, dict):
|
||||
converted = {}
|
||||
for key, value in v.items():
|
||||
if isinstance(key, str):
|
||||
try:
|
||||
# 尝试将字符串转换为 HitResult 枚举
|
||||
enum_key = HitResult(key)
|
||||
converted[enum_key] = value
|
||||
except ValueError:
|
||||
# 如果转换失败,跳过这个键值对
|
||||
continue
|
||||
else:
|
||||
converted[key] = value
|
||||
return converted
|
||||
return v
|
||||
|
||||
# optional
|
||||
# TODO: current_user_attributes
|
||||
|
||||
@@ -131,6 +151,18 @@ class Score(ScoreBase, table=True):
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pinned_order: int = Field(default=0, exclude=True)
|
||||
|
||||
@field_validator('gamemode', mode='before')
|
||||
@classmethod
|
||||
def validate_gamemode(cls, v):
|
||||
"""将字符串转换为 GameMode 枚举"""
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return GameMode(v)
|
||||
except ValueError:
|
||||
# 如果转换失败,返回默认值
|
||||
return GameMode.OSU
|
||||
return v
|
||||
|
||||
# optional
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
@@ -185,6 +217,26 @@ class ScoreResp(ScoreBase):
|
||||
return bool(v)
|
||||
return v
|
||||
|
||||
@field_validator('statistics', 'maximum_statistics', mode='before')
|
||||
@classmethod
|
||||
def validate_statistics_fields(cls, v):
|
||||
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
|
||||
if isinstance(v, dict):
|
||||
converted = {}
|
||||
for key, value in v.items():
|
||||
if isinstance(key, str):
|
||||
try:
|
||||
# 尝试将字符串转换为 HitResult 枚举
|
||||
enum_key = HitResult(key)
|
||||
converted[enum_key] = value
|
||||
except ValueError:
|
||||
# 如果转换失败,跳过这个键值对
|
||||
continue
|
||||
else:
|
||||
converted[key] = value
|
||||
return converted
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
s = cls.model_validate(score.model_dump())
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.models.score import GameMode
|
||||
|
||||
from .rank_history import RankHistory
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
@@ -43,6 +44,18 @@ class UserStatisticsBase(SQLModel):
|
||||
replays_watched_by_others: int = Field(default=0)
|
||||
is_ranked: bool = Field(default=True)
|
||||
|
||||
@field_validator('mode', mode='before')
|
||||
@classmethod
|
||||
def validate_mode(cls, v):
|
||||
"""将字符串转换为 GameMode 枚举"""
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return GameMode(v)
|
||||
except ValueError:
|
||||
# 如果转换失败,返回默认值
|
||||
return GameMode.OSU
|
||||
return v
|
||||
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
@@ -22,6 +22,20 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""自定义 JSON 编码器,支持 datetime 序列化"""
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def safe_json_dumps(data) -> str:
|
||||
"""安全的 JSON 序列化,支持 datetime 对象"""
|
||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
class RankingCacheService:
|
||||
"""用户排行榜缓存服务"""
|
||||
|
||||
@@ -95,7 +109,7 @@ class RankingCacheService:
|
||||
ttl = settings.ranking_cache_expire_minutes * 60
|
||||
await self.redis.set(
|
||||
cache_key,
|
||||
json.dumps(ranking_data, separators=(",", ":")),
|
||||
safe_json_dumps(ranking_data),
|
||||
ex=ttl
|
||||
)
|
||||
logger.debug(f"Cached ranking data for {cache_key}")
|
||||
@@ -136,7 +150,7 @@ class RankingCacheService:
|
||||
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
|
||||
await self.redis.set(
|
||||
cache_key,
|
||||
json.dumps(stats, separators=(",", ":")),
|
||||
safe_json_dumps(stats),
|
||||
ex=ttl
|
||||
)
|
||||
logger.debug(f"Cached stats for {cache_key}")
|
||||
@@ -174,7 +188,7 @@ class RankingCacheService:
|
||||
ttl = settings.ranking_cache_expire_minutes * 60
|
||||
await self.redis.set(
|
||||
cache_key,
|
||||
json.dumps(ranking_data, separators=(",", ":")),
|
||||
safe_json_dumps(ranking_data),
|
||||
ex=ttl
|
||||
)
|
||||
logger.debug(f"Cached country ranking data for {cache_key}")
|
||||
@@ -207,7 +221,7 @@ class RankingCacheService:
|
||||
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
|
||||
await self.redis.set(
|
||||
cache_key,
|
||||
json.dumps(stats, separators=(",", ":")),
|
||||
safe_json_dumps(stats),
|
||||
ex=ttl
|
||||
)
|
||||
logger.debug(f"Cached country stats for {cache_key}")
|
||||
@@ -283,13 +297,15 @@ class RankingCacheService:
|
||||
if not statistics_data:
|
||||
break # 没有更多数据
|
||||
|
||||
# 转换为响应格式
|
||||
# 转换为响应格式并确保正确序列化
|
||||
ranking_data = []
|
||||
for statistics in statistics_data:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(
|
||||
statistics, session, None, include
|
||||
)
|
||||
ranking_data.append(user_stats_resp.model_dump())
|
||||
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
|
||||
user_dict = json.loads(user_stats_resp.model_dump_json())
|
||||
ranking_data.append(user_dict)
|
||||
|
||||
# 缓存这一页的数据
|
||||
await self.cache_ranking(
|
||||
|
||||
165
test_ranking_serialization.py
Normal file
165
test_ranking_serialization.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试排行榜缓存序列化修复"""
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from datetime import datetime, UTC
|
||||
from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps
|
||||
|
||||
|
||||
def test_datetime_serialization():
|
||||
"""测试 datetime 序列化"""
|
||||
print("🧪 测试 datetime 序列化...")
|
||||
|
||||
test_data = {
|
||||
"id": 1,
|
||||
"username": "test_user",
|
||||
"last_updated": datetime.now(UTC),
|
||||
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
|
||||
"stats": {
|
||||
"pp": 1000.0,
|
||||
"accuracy": 95.5,
|
||||
"last_played": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
# 测试自定义编码器
|
||||
json_result = safe_json_dumps(test_data)
|
||||
print("✅ datetime 序列化成功")
|
||||
print(f" 序列化结果长度: {len(json_result)}")
|
||||
|
||||
# 验证可以重新解析
|
||||
import json
|
||||
parsed = json.loads(json_result)
|
||||
assert "last_updated" in parsed
|
||||
assert isinstance(parsed["last_updated"], str)
|
||||
print("✅ 序列化的 JSON 可以正确解析")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ datetime 序列化测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def test_boolean_serialization():
|
||||
"""测试布尔值序列化"""
|
||||
print("\n🧪 测试布尔值序列化...")
|
||||
|
||||
test_data = {
|
||||
"user": {
|
||||
"is_active": 1, # 数据库中的整数布尔值
|
||||
"is_supporter": 0, # 数据库中的整数布尔值
|
||||
"has_profile": True, # 正常布尔值
|
||||
},
|
||||
"stats": {
|
||||
"is_ranked": 1, # 数据库中的整数布尔值
|
||||
"verified": False, # 正常布尔值
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
json_result = safe_json_dumps(test_data)
|
||||
|
||||
# 检查是否有 Pydantic 序列化警告
|
||||
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
|
||||
if pydantic_warnings:
|
||||
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告")
|
||||
for warning in pydantic_warnings:
|
||||
print(f" {warning.message}")
|
||||
else:
|
||||
print("✅ 布尔值序列化无警告")
|
||||
|
||||
# 验证序列化结果
|
||||
import json
|
||||
parsed = json.loads(json_result)
|
||||
print(f"✅ 布尔值序列化成功,结果: {parsed}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 布尔值序列化测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def test_complex_ranking_data():
|
||||
"""测试复杂的排行榜数据序列化"""
|
||||
print("\n🧪 测试复杂排行榜数据序列化...")
|
||||
|
||||
# 模拟排行榜数据结构
|
||||
ranking_data = [
|
||||
{
|
||||
"id": 1,
|
||||
"user": {
|
||||
"id": 1,
|
||||
"username": "player1",
|
||||
"country_code": "US",
|
||||
"is_active": 1, # 整数布尔值
|
||||
"is_supporter": 0, # 整数布尔值
|
||||
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
|
||||
"last_visit": datetime.now(UTC),
|
||||
},
|
||||
"statistics": {
|
||||
"pp": 8000.0,
|
||||
"accuracy": 98.5,
|
||||
"play_count": 5000,
|
||||
"is_ranked": 1, # 整数布尔值
|
||||
"last_updated": datetime.now(UTC),
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"user": {
|
||||
"id": 2,
|
||||
"username": "player2",
|
||||
"country_code": "JP",
|
||||
"is_active": 1,
|
||||
"is_supporter": 1,
|
||||
"join_date": datetime(2019, 6, 15, tzinfo=UTC),
|
||||
"last_visit": datetime.now(UTC),
|
||||
},
|
||||
"statistics": {
|
||||
"pp": 7500.0,
|
||||
"accuracy": 97.8,
|
||||
"play_count": 4500,
|
||||
"is_ranked": 1,
|
||||
"last_updated": datetime.now(UTC),
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
json_result = safe_json_dumps(ranking_data)
|
||||
|
||||
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
|
||||
if pydantic_warnings:
|
||||
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
|
||||
for warning in pydantic_warnings:
|
||||
print(f" {warning.message}")
|
||||
else:
|
||||
print("✅ 复杂排行榜数据序列化无警告")
|
||||
|
||||
# 验证序列化结果
|
||||
import json
|
||||
parsed = json.loads(json_result)
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0]["user"]["username"] == "player1"
|
||||
print(f"✅ 复杂排行榜数据序列化成功,包含 {len(parsed)} 个条目")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 开始排行榜缓存序列化测试\n")
|
||||
|
||||
test_datetime_serialization()
|
||||
test_boolean_serialization()
|
||||
test_complex_ranking_data()
|
||||
|
||||
print("\n🎉 排行榜缓存序列化测试完成!")
|
||||
Reference in New Issue
Block a user