ruff fix
This commit is contained in:
@@ -88,6 +88,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
|||||||
|
|
||||||
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def _calculate_pp_sync():
|
def _calculate_pp_sync():
|
||||||
@@ -131,11 +132,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
|||||||
|
|
||||||
|
|
||||||
async def pre_fetch_and_calculate_pp(
|
async def pre_fetch_and_calculate_pp(
|
||||||
score: "Score",
|
score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
|
||||||
beatmap_id: int,
|
|
||||||
session: AsyncSession,
|
|
||||||
redis,
|
|
||||||
fetcher
|
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
优化版PP计算:预先获取beatmap文件并使用缓存
|
优化版PP计算:预先获取beatmap文件并使用缓存
|
||||||
@@ -148,9 +145,7 @@ async def pre_fetch_and_calculate_pp(
|
|||||||
if settings.suspicious_score_check:
|
if settings.suspicious_score_check:
|
||||||
beatmap_banned = (
|
beatmap_banned = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(exists()).where(
|
select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
|
||||||
col(BannedBeatmaps.beatmap_id) == beatmap_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
if beatmap_banned:
|
if beatmap_banned:
|
||||||
@@ -184,10 +179,7 @@ async def pre_fetch_and_calculate_pp(
|
|||||||
|
|
||||||
|
|
||||||
async def batch_calculate_pp(
|
async def batch_calculate_pp(
|
||||||
scores_data: list[tuple["Score", int]],
|
scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
|
||||||
session: AsyncSession,
|
|
||||||
redis,
|
|
||||||
fetcher
|
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
批量计算PP:适用于重新计算或批量处理场景
|
批量计算PP:适用于重新计算或批量处理场景
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from app.models.score import GameMode
|
|||||||
from .lazer_user import BASE_INCLUDES, User, UserResp
|
from .lazer_user import BASE_INCLUDES, User, UserResp
|
||||||
|
|
||||||
from pydantic import BaseModel, field_validator, model_validator
|
from pydantic import BaseModel, field_validator, model_validator
|
||||||
from sqlalchemy import Boolean, JSON, Column, DateTime, Text
|
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -205,7 +205,18 @@ class BeatmapsetResp(BeatmapsetBase):
|
|||||||
favourite_count: int = 0
|
favourite_count: int = 0
|
||||||
recent_favourites: list[UserResp] = Field(default_factory=list)
|
recent_favourites: list[UserResp] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before')
|
@field_validator(
|
||||||
|
"nsfw",
|
||||||
|
"spotlight",
|
||||||
|
"video",
|
||||||
|
"can_be_hyped",
|
||||||
|
"discussion_locked",
|
||||||
|
"storyboard",
|
||||||
|
"discussion_enabled",
|
||||||
|
"is_scoreable",
|
||||||
|
"has_favourited",
|
||||||
|
mode="before",
|
||||||
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_bool_fields(cls, v):
|
def validate_bool_fields(cls, v):
|
||||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
||||||
|
|||||||
@@ -2,13 +2,16 @@
|
|||||||
数据库字段类型工具
|
数据库字段类型工具
|
||||||
提供处理数据库和 Pydantic 之间类型转换的工具
|
提供处理数据库和 Pydantic 之间类型转换的工具
|
||||||
"""
|
"""
|
||||||
from typing import Any, Union
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from sqlalchemy import Boolean
|
from sqlalchemy import Boolean
|
||||||
|
|
||||||
|
|
||||||
def bool_field_validator(field_name: str):
|
def bool_field_validator(field_name: str):
|
||||||
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
|
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
|
||||||
|
|
||||||
@field_validator(field_name, mode="before")
|
@field_validator(field_name, mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_bool_field(cls, v: Any) -> bool:
|
def validate_bool_field(cls, v: Any) -> bool:
|
||||||
@@ -16,20 +19,21 @@ def bool_field_validator(field_name: str):
|
|||||||
if isinstance(v, int):
|
if isinstance(v, int):
|
||||||
return bool(v)
|
return bool(v)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return validate_bool_field
|
return validate_bool_field
|
||||||
|
|
||||||
|
|
||||||
def create_bool_field(**kwargs):
|
def create_bool_field(**kwargs):
|
||||||
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
|
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
|
||||||
from sqlmodel import Field, Column
|
from sqlmodel import Column, Field
|
||||||
|
|
||||||
# 如果没有指定 sa_column,则使用 Boolean 类型
|
# 如果没有指定 sa_column,则使用 Boolean 类型
|
||||||
if 'sa_column' not in kwargs:
|
if "sa_column" not in kwargs:
|
||||||
# 处理 index 参数
|
# 处理 index 参数
|
||||||
index = kwargs.pop('index', False)
|
index = kwargs.pop("index", False)
|
||||||
if index:
|
if index:
|
||||||
kwargs['sa_column'] = Column(Boolean, index=True)
|
kwargs["sa_column"] = Column(Boolean, index=True)
|
||||||
else:
|
else:
|
||||||
kwargs['sa_column'] = Column(Boolean)
|
kwargs["sa_column"] = Column(Boolean)
|
||||||
|
|
||||||
return Field(**kwargs)
|
return Field(**kwargs)
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
|||||||
is_qat: bool = False
|
is_qat: bool = False
|
||||||
is_bng: bool = False
|
is_bng: bool = False
|
||||||
|
|
||||||
@field_validator('playmode', mode='before')
|
@field_validator("playmode", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_playmode(cls, v):
|
def validate_playmode(cls, v):
|
||||||
"""将字符串转换为 GameMode 枚举"""
|
"""将字符串转换为 GameMode 枚举"""
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
|||||||
sa_column=Column(JSON), default_factory=dict
|
sa_column=Column(JSON), default_factory=dict
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator('maximum_statistics', mode='before')
|
@field_validator("maximum_statistics", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_maximum_statistics(cls, v):
|
def validate_maximum_statistics(cls, v):
|
||||||
"""处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
|
"""处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
|
||||||
@@ -151,7 +151,7 @@ class Score(ScoreBase, table=True):
|
|||||||
gamemode: GameMode = Field(index=True)
|
gamemode: GameMode = Field(index=True)
|
||||||
pinned_order: int = Field(default=0, exclude=True)
|
pinned_order: int = Field(default=0, exclude=True)
|
||||||
|
|
||||||
@field_validator('gamemode', mode='before')
|
@field_validator("gamemode", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_gamemode(cls, v):
|
def validate_gamemode(cls, v):
|
||||||
"""将字符串转换为 GameMode 枚举"""
|
"""将字符串转换为 GameMode 枚举"""
|
||||||
@@ -209,7 +209,16 @@ class ScoreResp(ScoreBase):
|
|||||||
ranked: bool = False
|
ranked: bool = False
|
||||||
current_user_attributes: CurrentUserAttributes | None = None
|
current_user_attributes: CurrentUserAttributes | None = None
|
||||||
|
|
||||||
@field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before')
|
@field_validator(
|
||||||
|
"has_replay",
|
||||||
|
"passed",
|
||||||
|
"preserve",
|
||||||
|
"is_perfect_combo",
|
||||||
|
"legacy_perfect",
|
||||||
|
"processed",
|
||||||
|
"ranked",
|
||||||
|
mode="before",
|
||||||
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_bool_fields(cls, v):
|
def validate_bool_fields(cls, v):
|
||||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
||||||
@@ -217,7 +226,7 @@ class ScoreResp(ScoreBase):
|
|||||||
return bool(v)
|
return bool(v)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('statistics', 'maximum_statistics', mode='before')
|
@field_validator("statistics", "maximum_statistics", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_statistics_fields(cls, v):
|
def validate_statistics_fields(cls, v):
|
||||||
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
|
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class UserStatisticsBase(SQLModel):
|
|||||||
replays_watched_by_others: int = Field(default=0)
|
replays_watched_by_others: int = Field(default=0)
|
||||||
is_ranked: bool = Field(default=True)
|
is_ranked: bool = Field(default=True)
|
||||||
|
|
||||||
@field_validator('mode', mode='before')
|
@field_validator("mode", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_mode(cls, v):
|
def validate_mode(cls, v):
|
||||||
"""将字符串转换为 GameMode 枚举"""
|
"""将字符串转换为 GameMode 枚举"""
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from app.dependencies.database import get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@@ -11,6 +10,7 @@ from httpx import AsyncClient
|
|||||||
|
|
||||||
class TokenAuthError(Exception):
|
class TokenAuthError(Exception):
|
||||||
"""Token 授权失败异常"""
|
"""Token 授权失败异常"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class BaseFetcher:
|
|||||||
return await self._request_with_retry(url, method, **kwargs)
|
return await self._request_with_retry(url, method, **kwargs)
|
||||||
|
|
||||||
async def _request_with_retry(
|
async def _request_with_retry(
|
||||||
self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs
|
self, url: str, method: str = "GET", max_retries: int | None = None, **kwargs
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
带重试机制的请求方法
|
带重试机制的请求方法
|
||||||
@@ -126,7 +126,9 @@ class BaseFetcher:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
|
logger.error(
|
||||||
|
f"Request failed after {max_retries + 1} attempts: {e}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# 如果所有重试都失败了
|
# 如果所有重试都失败了
|
||||||
@@ -194,9 +196,13 @@ class BaseFetcher:
|
|||||||
f"fetcher:refresh_token:{self.client_id}",
|
f"fetcher:refresh_token:{self.client_id}",
|
||||||
self.refresh_token,
|
self.refresh_token,
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully refreshed access token for client {self.client_id}")
|
logger.info(
|
||||||
|
f"Successfully refreshed access token for client {self.client_id}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
|
logger.error(
|
||||||
|
f"Failed to refresh access token for client {self.client_id}: {e}"
|
||||||
|
)
|
||||||
# 清除无效的 token,要求重新授权
|
# 清除无效的 token,要求重新授权
|
||||||
self.access_token = ""
|
self.access_token = ""
|
||||||
self.refresh_token = ""
|
self.refresh_token = ""
|
||||||
@@ -204,7 +210,9 @@ class BaseFetcher:
|
|||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
await redis.delete(f"fetcher:access_token:{self.client_id}")
|
await redis.delete(f"fetcher:access_token:{self.client_id}")
|
||||||
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
|
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
|
||||||
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
|
logger.warning(
|
||||||
|
f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _trigger_reauthorization(self) -> None:
|
async def _trigger_reauthorization(self) -> None:
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
return json.loads(cursor_json)
|
return json.loads(cursor_json)
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
||||||
logger.opt(colors=True).debug(
|
logger.opt(colors=True).debug(
|
||||||
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
|
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
|
||||||
@@ -164,9 +165,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
# 将结果缓存 15 分钟
|
# 将结果缓存 15 分钟
|
||||||
cache_ttl = 15 * 60 # 15 分钟
|
cache_ttl = 15 * 60 # 15 分钟
|
||||||
await redis_client.set(
|
await redis_client.set(
|
||||||
cache_key,
|
cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
|
||||||
json.dumps(api_response, separators=(",", ":")),
|
|
||||||
ex=cache_ttl
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.opt(colors=True).debug(
|
logger.opt(colors=True).debug(
|
||||||
@@ -178,10 +177,12 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
|
|
||||||
# 智能预取:只在用户明确搜索时才预取,避免过多API请求
|
# 智能预取:只在用户明确搜索时才预取,避免过多API请求
|
||||||
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
|
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
|
||||||
if (api_response.get("cursor") and
|
if api_response.get("cursor") and (
|
||||||
(query.q or query.s != "leaderboard" or cursor)):
|
query.q or query.s != "leaderboard" or cursor
|
||||||
|
):
|
||||||
# 在后台预取下1页(减少预取量)
|
# 在后台预取下1页(减少预取量)
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# 不立即创建任务,而是延迟一段时间再预取
|
# 不立即创建任务,而是延迟一段时间再预取
|
||||||
async def delayed_prefetch():
|
async def delayed_prefetch():
|
||||||
await asyncio.sleep(3.0) # 延迟3秒
|
await asyncio.sleep(3.0) # 延迟3秒
|
||||||
@@ -200,8 +201,11 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def prefetch_next_pages(
|
async def prefetch_next_pages(
|
||||||
self, query: SearchQueryModel, current_cursor: Cursor,
|
self,
|
||||||
redis_client: redis.Redis, pages: int = 3
|
query: SearchQueryModel,
|
||||||
|
current_cursor: Cursor,
|
||||||
|
redis_client: redis.Redis,
|
||||||
|
pages: int = 3,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""预取下几页内容"""
|
"""预取下几页内容"""
|
||||||
if not current_cursor:
|
if not current_cursor:
|
||||||
@@ -269,7 +273,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
await redis_client.set(
|
await redis_client.set(
|
||||||
next_cache_key,
|
next_cache_key,
|
||||||
json.dumps(api_response, separators=(",", ":")),
|
json.dumps(api_response, separators=(",", ":")),
|
||||||
ex=prefetch_ttl
|
ex=prefetch_ttl,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.opt(colors=True).debug(
|
logger.opt(colors=True).debug(
|
||||||
@@ -317,7 +321,6 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if api_response.get("cursor"):
|
if api_response.get("cursor"):
|
||||||
cursor_dict = api_response["cursor"]
|
cursor_dict = api_response["cursor"]
|
||||||
api_response["cursor_string"] = self._encode_cursor(cursor_dict)
|
api_response["cursor_string"] = self._encode_cursor(cursor_dict)
|
||||||
@@ -327,7 +330,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
await redis_client.set(
|
await redis_client.set(
|
||||||
cache_key,
|
cache_key,
|
||||||
json.dumps(api_response, separators=(",", ":")),
|
json.dumps(api_response, separators=(",", ":")),
|
||||||
ex=cache_ttl
|
ex=cache_ttl,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
@@ -335,7 +338,6 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
|
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if api_response.get("cursor"):
|
if api_response.get("cursor"):
|
||||||
await self.prefetch_next_pages(
|
await self.prefetch_next_pages(
|
||||||
query, api_response["cursor"], redis_client, pages=2
|
query, api_response["cursor"], redis_client, pages=2
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Rate limiter for osu! API requests to avoid abuse detection.
|
|||||||
- 突发:短时间内最多 200 次额外请求
|
- 突发:短时间内最多 200 次额外请求
|
||||||
- 建议:每分钟不超过 60 次请求以避免滥用检测
|
- 建议:每分钟不超过 60 次请求以避免滥用检测
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import asyncio
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
from app.database.lazer_user import User
|
from app.database.lazer_user import User
|
||||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||||
from app.dependencies.database import Database, get_redis
|
from app.dependencies.database import Database, get_redis
|
||||||
@@ -132,7 +131,9 @@ async def get_user(
|
|||||||
if is_id_query:
|
if is_id_query:
|
||||||
try:
|
try:
|
||||||
user_id_for_cache = int(user)
|
user_id_for_cache = int(user)
|
||||||
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
|
cached_v1_user = await cache_service.get_v1_user_from_cache(
|
||||||
|
user_id_for_cache, ruleset
|
||||||
|
)
|
||||||
if cached_v1_user:
|
if cached_v1_user:
|
||||||
return [V1User(**cached_v1_user)]
|
return [V1User(**cached_v1_user)]
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
|
|||||||
@@ -2,15 +2,15 @@
|
|||||||
缓存管理和监控接口
|
缓存管理和监控接口
|
||||||
提供缓存统计、清理和预热功能
|
提供缓存统计、清理和预热功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.dependencies.database import get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.dependencies.user import get_current_user
|
|
||||||
from app.service.user_cache_service import get_user_cache_service
|
from app.service.user_cache_service import get_user_cache_service
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Security
|
from fastapi import Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
@@ -54,13 +54,10 @@ async def get_cache_stats(
|
|||||||
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
|
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
|
||||||
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
|
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
|
||||||
|
|
||||||
return CacheStatsResponse(
|
return CacheStatsResponse(user_cache=user_cache_stats, redis_info=redis_stats)
|
||||||
user_cache=user_cache_stats,
|
|
||||||
redis_info=redis_stats
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(500, f"Failed to get cache stats: {str(e)}")
|
raise HTTPException(500, f"Failed to get cache stats: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -80,7 +77,7 @@ async def invalidate_user_cache(
|
|||||||
await cache_service.invalidate_v1_user_cache(user_id)
|
await cache_service.invalidate_v1_user_cache(user_id)
|
||||||
return {"message": f"Cache invalidated for user {user_id}"}
|
return {"message": f"Cache invalidated for user {user_id}"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
|
raise HTTPException(500, f"Failed to invalidate cache: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -106,7 +103,7 @@ async def clear_all_user_cache(
|
|||||||
return {"message": "No cache entries found"}
|
return {"message": "No cache entries found"}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(500, f"Failed to clear cache: {str(e)}")
|
raise HTTPException(500, f"Failed to clear cache: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
class CacheWarmupRequest(BaseModel):
|
class CacheWarmupRequest(BaseModel):
|
||||||
@@ -131,14 +128,18 @@ async def warmup_cache(
|
|||||||
if request.user_ids:
|
if request.user_ids:
|
||||||
# 预热指定用户
|
# 预热指定用户
|
||||||
from app.dependencies.database import with_db
|
from app.dependencies.database import with_db
|
||||||
|
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
await cache_service.preload_user_cache(session, request.user_ids)
|
await cache_service.preload_user_cache(session, request.user_ids)
|
||||||
return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
|
return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
|
||||||
else:
|
else:
|
||||||
# 预热活跃用户
|
# 预热活跃用户
|
||||||
from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
|
from app.scheduler.user_cache_scheduler import (
|
||||||
|
schedule_user_cache_preload_task,
|
||||||
|
)
|
||||||
|
|
||||||
await schedule_user_cache_preload_task()
|
await schedule_user_cache_preload_task()
|
||||||
return {"message": f"Warmed up cache for top {request.limit} active users"}
|
return {"message": f"Warmed up cache for top {request.limit} active users"}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(500, f"Failed to warmup cache: {str(e)}")
|
raise HTTPException(500, f"Failed to warmup cache: {e!s}")
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ async def get_country_ranking(
|
|||||||
|
|
||||||
# 创建后台任务来缓存数据
|
# 创建后台任务来缓存数据
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.create_task(cache_task)
|
asyncio.create_task(cache_task)
|
||||||
|
|
||||||
# 返回当前页的结果
|
# 返回当前页的结果
|
||||||
@@ -144,16 +145,12 @@ async def get_user_ranking(
|
|||||||
cache_service = get_ranking_cache_service(redis)
|
cache_service = get_ranking_cache_service(redis)
|
||||||
|
|
||||||
# 尝试从缓存获取数据
|
# 尝试从缓存获取数据
|
||||||
cached_data = await cache_service.get_cached_ranking(
|
cached_data = await cache_service.get_cached_ranking(ruleset, type, country, page)
|
||||||
ruleset, type, country, page
|
|
||||||
)
|
|
||||||
|
|
||||||
if cached_data:
|
if cached_data:
|
||||||
# 从缓存返回数据
|
# 从缓存返回数据
|
||||||
return TopUsersResponse(
|
return TopUsersResponse(
|
||||||
ranking=[
|
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
|
||||||
UserStatisticsResp.model_validate(item) for item in cached_data
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 缓存未命中,从数据库查询
|
# 缓存未命中,从数据库查询
|
||||||
@@ -191,11 +188,17 @@ async def get_user_ranking(
|
|||||||
# 使用配置文件中的TTL设置
|
# 使用配置文件中的TTL设置
|
||||||
cache_data = [item.model_dump() for item in ranking_data]
|
cache_data = [item.model_dump() for item in ranking_data]
|
||||||
cache_task = cache_service.cache_ranking(
|
cache_task = cache_service.cache_ranking(
|
||||||
ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60
|
ruleset,
|
||||||
|
type,
|
||||||
|
cache_data,
|
||||||
|
country,
|
||||||
|
page,
|
||||||
|
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建后台任务来缓存数据
|
# 创建后台任务来缓存数据
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.create_task(cache_task)
|
asyncio.create_task(cache_task)
|
||||||
|
|
||||||
resp = TopUsersResponse(ranking=ranking_data)
|
resp = TopUsersResponse(ranking=ranking_data)
|
||||||
|
|||||||
@@ -73,7 +73,9 @@ async def get_users(
|
|||||||
# 查询未缓存的用户
|
# 查询未缓存的用户
|
||||||
if uncached_user_ids:
|
if uncached_user_ids:
|
||||||
searched_users = (
|
searched_users = (
|
||||||
await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
|
await session.exec(
|
||||||
|
select(User).where(col(User.id).in_(uncached_user_ids))
|
||||||
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# 将查询到的用户添加到缓存并返回
|
# 将查询到的用户添加到缓存并返回
|
||||||
@@ -275,9 +277,13 @@ async def get_user_beatmapsets(
|
|||||||
# 异步缓存结果
|
# 异步缓存结果
|
||||||
async def cache_beatmapsets():
|
async def cache_beatmapsets():
|
||||||
try:
|
try:
|
||||||
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
|
await cache_service.cache_user_beatmapsets(
|
||||||
|
user_id, type.value, resp, limit, offset
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
|
logger.error(
|
||||||
|
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
asyncio.create_task(cache_beatmapsets())
|
asyncio.create_task(cache_beatmapsets())
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""缓存调度器模块"""
|
"""缓存调度器模块"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
|
from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.scheduler.user_cache_scheduler import (
|
from app.scheduler.user_cache_scheduler import (
|
||||||
@@ -59,7 +59,9 @@ class CacheScheduler:
|
|||||||
# 从配置文件获取间隔设置
|
# 从配置文件获取间隔设置
|
||||||
check_interval = 5 * 60 # 5分钟检查间隔
|
check_interval = 5 * 60 # 5分钟检查间隔
|
||||||
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
|
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
|
||||||
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
|
ranking_cache_interval = (
|
||||||
|
settings.ranking_cache_refresh_interval_minutes * 60
|
||||||
|
) # 从配置读取
|
||||||
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
|
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
|
||||||
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔
|
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔
|
||||||
|
|
||||||
@@ -131,13 +133,12 @@ class CacheScheduler:
|
|||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
|
|
||||||
# 导入排行榜缓存服务
|
# 导入排行榜缓存服务
|
||||||
|
# 使用独立的数据库会话
|
||||||
|
from app.dependencies.database import with_db
|
||||||
from app.service.ranking_cache_service import (
|
from app.service.ranking_cache_service import (
|
||||||
get_ranking_cache_service,
|
|
||||||
schedule_ranking_refresh_task,
|
schedule_ranking_refresh_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用独立的数据库会话
|
|
||||||
from app.dependencies.database import with_db
|
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
await schedule_ranking_refresh_task(session, redis)
|
await schedule_ranking_refresh_task(session, redis)
|
||||||
|
|
||||||
@@ -171,6 +172,7 @@ class CacheScheduler:
|
|||||||
# Beatmap缓存调度器(保持向后兼容)
|
# Beatmap缓存调度器(保持向后兼容)
|
||||||
class BeatmapsetCacheScheduler(CacheScheduler):
|
class BeatmapsetCacheScheduler(CacheScheduler):
|
||||||
"""谱面集缓存调度器 - 为了向后兼容"""
|
"""谱面集缓存调度器 - 为了向后兼容"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
用户缓存预热任务调度器
|
用户缓存预热任务调度器
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import User
|
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.service.user_cache_service import get_user_cache_service
|
from app.service.user_cache_service import get_user_cache_service
|
||||||
|
|
||||||
@@ -31,6 +31,7 @@ async def schedule_user_cache_preload_task():
|
|||||||
|
|
||||||
# 使用独立的数据库会话
|
# 使用独立的数据库会话
|
||||||
from app.dependencies.database import with_db
|
from app.dependencies.database import with_db
|
||||||
|
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
# 获取最近24小时内活跃的用户(提交过成绩的用户)
|
# 获取最近24小时内活跃的用户(提交过成绩的用户)
|
||||||
recent_time = datetime.now(UTC) - timedelta(hours=24)
|
recent_time = datetime.now(UTC) - timedelta(hours=24)
|
||||||
@@ -68,6 +69,7 @@ async def schedule_user_cache_warmup_task():
|
|||||||
|
|
||||||
# 使用独立的数据库会话
|
# 使用独立的数据库会话
|
||||||
from app.dependencies.database import with_db
|
from app.dependencies.database import with_db
|
||||||
|
|
||||||
async with with_db() as session:
|
async with with_db() as session:
|
||||||
# 获取全球排行榜前100的用户
|
# 获取全球排行榜前100的用户
|
||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
Beatmap缓存预取服务
|
Beatmap缓存预取服务
|
||||||
用于提前缓存热门beatmap,减少成绩计算时的获取延迟
|
用于提前缓存热门beatmap,减少成绩计算时的获取延迟
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -155,9 +156,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
|
|||||||
|
|
||||||
|
|
||||||
async def schedule_preload_task(
|
async def schedule_preload_task(
|
||||||
session: AsyncSession,
|
session: AsyncSession, redis: Redis, fetcher: "Fetcher"
|
||||||
redis: Redis,
|
|
||||||
fetcher: "Fetcher"
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
定时预加载任务
|
定时预加载任务
|
||||||
|
|||||||
@@ -2,11 +2,12 @@
|
|||||||
用户排行榜缓存服务
|
用户排行榜缓存服务
|
||||||
用于缓存用户排行榜数据,减轻数据库压力
|
用于缓存用户排行榜数据,减轻数据库压力
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import UTC, datetime
|
||||||
import json
|
import json
|
||||||
from datetime import UTC, datetime, timedelta
|
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
|
|||||||
|
|
||||||
def safe_json_dumps(data) -> str:
|
def safe_json_dumps(data) -> str:
|
||||||
"""安全的 JSON 序列化,支持 datetime 对象"""
|
"""安全的 JSON 序列化,支持 datetime 对象"""
|
||||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
|
return json.dumps(
|
||||||
|
data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RankingCacheService:
|
class RankingCacheService:
|
||||||
@@ -107,11 +110,7 @@ class RankingCacheService:
|
|||||||
# 使用配置文件的TTL设置
|
# 使用配置文件的TTL设置
|
||||||
if ttl is None:
|
if ttl is None:
|
||||||
ttl = settings.ranking_cache_expire_minutes * 60
|
ttl = settings.ranking_cache_expire_minutes * 60
|
||||||
await self.redis.set(
|
await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
|
||||||
cache_key,
|
|
||||||
safe_json_dumps(ranking_data),
|
|
||||||
ex=ttl
|
|
||||||
)
|
|
||||||
logger.debug(f"Cached ranking data for {cache_key}")
|
logger.debug(f"Cached ranking data for {cache_key}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching ranking: {e}")
|
logger.error(f"Error caching ranking: {e}")
|
||||||
@@ -148,11 +147,7 @@ class RankingCacheService:
|
|||||||
# 使用配置文件的TTL设置,统计信息缓存时间更长
|
# 使用配置文件的TTL设置,统计信息缓存时间更长
|
||||||
if ttl is None:
|
if ttl is None:
|
||||||
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
|
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
|
||||||
await self.redis.set(
|
await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
|
||||||
cache_key,
|
|
||||||
safe_json_dumps(stats),
|
|
||||||
ex=ttl
|
|
||||||
)
|
|
||||||
logger.debug(f"Cached stats for {cache_key}")
|
logger.debug(f"Cached stats for {cache_key}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching stats: {e}")
|
logger.error(f"Error caching stats: {e}")
|
||||||
@@ -186,11 +181,7 @@ class RankingCacheService:
|
|||||||
cache_key = self._get_country_cache_key(ruleset, page)
|
cache_key = self._get_country_cache_key(ruleset, page)
|
||||||
if ttl is None:
|
if ttl is None:
|
||||||
ttl = settings.ranking_cache_expire_minutes * 60
|
ttl = settings.ranking_cache_expire_minutes * 60
|
||||||
await self.redis.set(
|
await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
|
||||||
cache_key,
|
|
||||||
safe_json_dumps(ranking_data),
|
|
||||||
ex=ttl
|
|
||||||
)
|
|
||||||
logger.debug(f"Cached country ranking data for {cache_key}")
|
logger.debug(f"Cached country ranking data for {cache_key}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching country ranking: {e}")
|
logger.error(f"Error caching country ranking: {e}")
|
||||||
@@ -219,11 +210,7 @@ class RankingCacheService:
|
|||||||
cache_key = self._get_country_stats_cache_key(ruleset)
|
cache_key = self._get_country_stats_cache_key(ruleset)
|
||||||
if ttl is None:
|
if ttl is None:
|
||||||
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
|
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
|
||||||
await self.redis.set(
|
await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
|
||||||
cache_key,
|
|
||||||
safe_json_dumps(stats),
|
|
||||||
ex=ttl
|
|
||||||
)
|
|
||||||
logger.debug(f"Cached country stats for {cache_key}")
|
logger.debug(f"Cached country stats for {cache_key}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching country stats: {e}")
|
logger.error(f"Error caching country stats: {e}")
|
||||||
@@ -238,7 +225,9 @@ class RankingCacheService:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""刷新排行榜缓存"""
|
"""刷新排行榜缓存"""
|
||||||
if self._refreshing:
|
if self._refreshing:
|
||||||
logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
|
logger.info(
|
||||||
|
f"Ranking cache refresh already in progress for {ruleset}:{type}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 使用配置文件的设置
|
# 使用配置文件的设置
|
||||||
@@ -264,7 +253,9 @@ class RankingCacheService:
|
|||||||
order_by = col(UserStatistics.ranked_score).desc()
|
order_by = col(UserStatistics.ranked_score).desc()
|
||||||
|
|
||||||
if country:
|
if country:
|
||||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
wheres.append(
|
||||||
|
col(UserStatistics.user).has(country_code=country.upper())
|
||||||
|
)
|
||||||
|
|
||||||
# 获取总用户数用于统计
|
# 获取总用户数用于统计
|
||||||
total_users_query = select(UserStatistics).where(*wheres)
|
total_users_query = select(UserStatistics).where(*wheres)
|
||||||
@@ -308,9 +299,7 @@ class RankingCacheService:
|
|||||||
ranking_data.append(user_dict)
|
ranking_data.append(user_dict)
|
||||||
|
|
||||||
# 缓存这一页的数据
|
# 缓存这一页的数据
|
||||||
await self.cache_ranking(
|
await self.cache_ranking(ruleset, type, ranking_data, country, page)
|
||||||
ruleset, type, ranking_data, country, page
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加延迟避免数据库过载
|
# 添加延迟避免数据库过载
|
||||||
if page < max_pages:
|
if page < max_pages:
|
||||||
@@ -334,7 +323,9 @@ class RankingCacheService:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""刷新地区排行榜缓存"""
|
"""刷新地区排行榜缓存"""
|
||||||
if self._refreshing:
|
if self._refreshing:
|
||||||
logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
|
logger.info(
|
||||||
|
f"Country ranking cache refresh already in progress for {ruleset}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if max_pages is None:
|
if max_pages is None:
|
||||||
@@ -346,6 +337,7 @@ class RankingCacheService:
|
|||||||
|
|
||||||
# 获取所有国家
|
# 获取所有国家
|
||||||
from app.database import User
|
from app.database import User
|
||||||
|
|
||||||
countries = (await session.exec(select(User.country_code).distinct())).all()
|
countries = (await session.exec(select(User.country_code).distinct())).all()
|
||||||
|
|
||||||
# 计算每个国家的统计数据
|
# 计算每个国家的统计数据
|
||||||
@@ -430,6 +422,7 @@ class RankingCacheService:
|
|||||||
|
|
||||||
# 获取需要缓存的国家列表(活跃用户数量前20的国家)
|
# 获取需要缓存的国家列表(活跃用户数量前20的国家)
|
||||||
from app.database import User
|
from app.database import User
|
||||||
|
|
||||||
from sqlmodel import func
|
from sqlmodel import func
|
||||||
|
|
||||||
countries_query = (
|
countries_query = (
|
||||||
@@ -456,7 +449,9 @@ class RankingCacheService:
|
|||||||
for country in top_countries:
|
for country in top_countries:
|
||||||
for mode in game_modes:
|
for mode in game_modes:
|
||||||
for ranking_type in ranking_types:
|
for ranking_type in ranking_types:
|
||||||
task = self.refresh_ranking_cache(session, mode, ranking_type, country)
|
task = self.refresh_ranking_cache(
|
||||||
|
session, mode, ranking_type, country
|
||||||
|
)
|
||||||
refresh_tasks.append(task)
|
refresh_tasks.append(task)
|
||||||
|
|
||||||
# 地区排行榜
|
# 地区排行榜
|
||||||
@@ -498,12 +493,14 @@ class RankingCacheService:
|
|||||||
if keys:
|
if keys:
|
||||||
await self.redis.delete(*keys)
|
await self.redis.delete(*keys)
|
||||||
deleted_keys += len(keys)
|
deleted_keys += len(keys)
|
||||||
logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
|
logger.info(
|
||||||
|
f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
|
||||||
|
)
|
||||||
elif ruleset:
|
elif ruleset:
|
||||||
# 删除特定游戏模式的所有缓存
|
# 删除特定游戏模式的所有缓存
|
||||||
patterns = [
|
patterns = [
|
||||||
f"ranking:{ruleset}:*",
|
f"ranking:{ruleset}:*",
|
||||||
f"country_ranking:{ruleset}:*" if include_country_ranking else None
|
f"country_ranking:{ruleset}:*" if include_country_ranking else None,
|
||||||
]
|
]
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
if pattern:
|
if pattern:
|
||||||
|
|||||||
@@ -2,24 +2,23 @@
|
|||||||
用户缓存服务
|
用户缓存服务
|
||||||
用于缓存用户信息,提供热缓存和实时刷新功能
|
用于缓存用户信息,提供热缓存和实时刷新功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
from datetime import UTC, datetime, timedelta
|
from typing import TYPE_CHECKING, Any
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import User, UserResp
|
from app.database import User, UserResp
|
||||||
from app.database.lazer_user import SEARCH_INCLUDED
|
from app.database.lazer_user import SEARCH_INCLUDED
|
||||||
from app.database.pp_best_score import PPBestScore
|
from app.database.score import ScoreResp
|
||||||
from app.database.score import Score, ScoreResp
|
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from sqlmodel import col, exists, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -48,16 +47,16 @@ class UserCacheService:
|
|||||||
self._refreshing = False
|
self._refreshing = False
|
||||||
self._background_tasks: set = set()
|
self._background_tasks: set = set()
|
||||||
|
|
||||||
def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
|
def _get_v1_user_cache_key(
|
||||||
|
self, user_id: int, ruleset: GameMode | None = None
|
||||||
|
) -> str:
|
||||||
"""生成 V1 用户缓存键"""
|
"""生成 V1 用户缓存键"""
|
||||||
if ruleset:
|
if ruleset:
|
||||||
return f"v1_user:{user_id}:ruleset:{ruleset}"
|
return f"v1_user:{user_id}:ruleset:{ruleset}"
|
||||||
return f"v1_user:{user_id}"
|
return f"v1_user:{user_id}"
|
||||||
|
|
||||||
async def get_v1_user_from_cache(
|
async def get_v1_user_from_cache(
|
||||||
self,
|
self, user_id: int, ruleset: GameMode | None = None
|
||||||
user_id: int,
|
|
||||||
ruleset: GameMode | None = None
|
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
"""从缓存获取 V1 用户信息"""
|
"""从缓存获取 V1 用户信息"""
|
||||||
try:
|
try:
|
||||||
@@ -76,7 +75,7 @@ class UserCacheService:
|
|||||||
user_data: dict,
|
user_data: dict,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
ruleset: GameMode | None = None,
|
ruleset: GameMode | None = None,
|
||||||
expire_seconds: int | None = None
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
"""缓存 V1 用户信息"""
|
"""缓存 V1 用户信息"""
|
||||||
try:
|
try:
|
||||||
@@ -97,7 +96,9 @@ class UserCacheService:
|
|||||||
keys = await self.redis.keys(pattern)
|
keys = await self.redis.keys(pattern)
|
||||||
if keys:
|
if keys:
|
||||||
await self.redis.delete(*keys)
|
await self.redis.delete(*keys)
|
||||||
logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
|
logger.info(
|
||||||
|
f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error invalidating V1 user cache: {e}")
|
logger.error(f"Error invalidating V1 user cache: {e}")
|
||||||
|
|
||||||
@@ -113,26 +114,20 @@ class UserCacheService:
|
|||||||
score_type: str,
|
score_type: str,
|
||||||
mode: GameMode | None = None,
|
mode: GameMode | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0
|
offset: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""生成用户成绩缓存键"""
|
"""生成用户成绩缓存键"""
|
||||||
mode_part = f":{mode}" if mode else ""
|
mode_part = f":{mode}" if mode else ""
|
||||||
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
|
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
|
||||||
|
|
||||||
def _get_user_beatmapsets_cache_key(
|
def _get_user_beatmapsets_cache_key(
|
||||||
self,
|
self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
|
||||||
user_id: int,
|
|
||||||
beatmapset_type: str,
|
|
||||||
limit: int = 100,
|
|
||||||
offset: int = 0
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""生成用户谱面集缓存键"""
|
"""生成用户谱面集缓存键"""
|
||||||
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
|
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
|
||||||
|
|
||||||
async def get_user_from_cache(
|
async def get_user_from_cache(
|
||||||
self,
|
self, user_id: int, ruleset: GameMode | None = None
|
||||||
user_id: int,
|
|
||||||
ruleset: GameMode | None = None
|
|
||||||
) -> UserResp | None:
|
) -> UserResp | None:
|
||||||
"""从缓存获取用户信息"""
|
"""从缓存获取用户信息"""
|
||||||
try:
|
try:
|
||||||
@@ -151,7 +146,7 @@ class UserCacheService:
|
|||||||
self,
|
self,
|
||||||
user_resp: UserResp,
|
user_resp: UserResp,
|
||||||
ruleset: GameMode | None = None,
|
ruleset: GameMode | None = None,
|
||||||
expire_seconds: int | None = None
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
"""缓存用户信息"""
|
"""缓存用户信息"""
|
||||||
try:
|
try:
|
||||||
@@ -173,14 +168,18 @@ class UserCacheService:
|
|||||||
score_type: str,
|
score_type: str,
|
||||||
mode: GameMode | None = None,
|
mode: GameMode | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0
|
offset: int = 0,
|
||||||
) -> list[ScoreResp] | None:
|
) -> list[ScoreResp] | None:
|
||||||
"""从缓存获取用户成绩"""
|
"""从缓存获取用户成绩"""
|
||||||
try:
|
try:
|
||||||
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
|
cache_key = self._get_user_scores_cache_key(
|
||||||
|
user_id, score_type, mode, limit, offset
|
||||||
|
)
|
||||||
cached_data = await self.redis.get(cache_key)
|
cached_data = await self.redis.get(cache_key)
|
||||||
if cached_data:
|
if cached_data:
|
||||||
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
|
logger.debug(
|
||||||
|
f"User scores cache hit for user {user_id}, type {score_type}"
|
||||||
|
)
|
||||||
data = json.loads(cached_data)
|
data = json.loads(cached_data)
|
||||||
return [ScoreResp(**score_data) for score_data in data]
|
return [ScoreResp(**score_data) for score_data in data]
|
||||||
return None
|
return None
|
||||||
@@ -196,34 +195,38 @@ class UserCacheService:
|
|||||||
mode: GameMode | None = None,
|
mode: GameMode | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
expire_seconds: int | None = None
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
"""缓存用户成绩"""
|
"""缓存用户成绩"""
|
||||||
try:
|
try:
|
||||||
if expire_seconds is None:
|
if expire_seconds is None:
|
||||||
expire_seconds = settings.user_scores_cache_expire_seconds
|
expire_seconds = settings.user_scores_cache_expire_seconds
|
||||||
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
|
cache_key = self._get_user_scores_cache_key(
|
||||||
|
user_id, score_type, mode, limit, offset
|
||||||
|
)
|
||||||
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
|
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
|
||||||
scores_json_list = [score.model_dump_json() for score in scores]
|
scores_json_list = [score.model_dump_json() for score in scores]
|
||||||
cached_data = f"[{','.join(scores_json_list)}]"
|
cached_data = f"[{','.join(scores_json_list)}]"
|
||||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||||
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
|
logger.debug(
|
||||||
|
f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching user scores: {e}")
|
logger.error(f"Error caching user scores: {e}")
|
||||||
|
|
||||||
async def get_user_beatmapsets_from_cache(
|
async def get_user_beatmapsets_from_cache(
|
||||||
self,
|
self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
|
||||||
user_id: int,
|
|
||||||
beatmapset_type: str,
|
|
||||||
limit: int = 100,
|
|
||||||
offset: int = 0
|
|
||||||
) -> list[Any] | None:
|
) -> list[Any] | None:
|
||||||
"""从缓存获取用户谱面集"""
|
"""从缓存获取用户谱面集"""
|
||||||
try:
|
try:
|
||||||
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
|
cache_key = self._get_user_beatmapsets_cache_key(
|
||||||
|
user_id, beatmapset_type, limit, offset
|
||||||
|
)
|
||||||
cached_data = await self.redis.get(cache_key)
|
cached_data = await self.redis.get(cache_key)
|
||||||
if cached_data:
|
if cached_data:
|
||||||
logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
|
logger.debug(
|
||||||
|
f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
|
||||||
|
)
|
||||||
return json.loads(cached_data)
|
return json.loads(cached_data)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -237,23 +240,27 @@ class UserCacheService:
|
|||||||
beatmapsets: list[Any],
|
beatmapsets: list[Any],
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
expire_seconds: int | None = None
|
expire_seconds: int | None = None,
|
||||||
):
|
):
|
||||||
"""缓存用户谱面集"""
|
"""缓存用户谱面集"""
|
||||||
try:
|
try:
|
||||||
if expire_seconds is None:
|
if expire_seconds is None:
|
||||||
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
|
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
|
||||||
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
|
cache_key = self._get_user_beatmapsets_cache_key(
|
||||||
|
user_id, beatmapset_type, limit, offset
|
||||||
|
)
|
||||||
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
|
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
|
||||||
serialized_beatmapsets = []
|
serialized_beatmapsets = []
|
||||||
for bms in beatmapsets:
|
for bms in beatmapsets:
|
||||||
if hasattr(bms, 'model_dump_json'):
|
if hasattr(bms, "model_dump_json"):
|
||||||
serialized_beatmapsets.append(bms.model_dump_json())
|
serialized_beatmapsets.append(bms.model_dump_json())
|
||||||
else:
|
else:
|
||||||
serialized_beatmapsets.append(safe_json_dumps(bms))
|
serialized_beatmapsets.append(safe_json_dumps(bms))
|
||||||
cached_data = f"[{','.join(serialized_beatmapsets)}]"
|
cached_data = f"[{','.join(serialized_beatmapsets)}]"
|
||||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||||
logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
|
logger.debug(
|
||||||
|
f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching user beatmapsets: {e}")
|
logger.error(f"Error caching user beatmapsets: {e}")
|
||||||
|
|
||||||
@@ -269,7 +276,9 @@ class UserCacheService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error invalidating user cache: {e}")
|
logger.error(f"Error invalidating user cache: {e}")
|
||||||
|
|
||||||
async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
|
async def invalidate_user_scores_cache(
|
||||||
|
self, user_id: int, mode: GameMode | None = None
|
||||||
|
):
|
||||||
"""使用户成绩缓存失效"""
|
"""使用户成绩缓存失效"""
|
||||||
try:
|
try:
|
||||||
# 删除用户成绩相关缓存
|
# 删除用户成绩相关缓存
|
||||||
@@ -278,7 +287,9 @@ class UserCacheService:
|
|||||||
keys = await self.redis.keys(pattern)
|
keys = await self.redis.keys(pattern)
|
||||||
if keys:
|
if keys:
|
||||||
await self.redis.delete(*keys)
|
await self.redis.delete(*keys)
|
||||||
logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
|
logger.info(
|
||||||
|
f"Invalidated {len(keys)} score cache entries for user {user_id}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error invalidating user scores cache: {e}")
|
logger.error(f"Error invalidating user scores cache: {e}")
|
||||||
|
|
||||||
@@ -293,9 +304,7 @@ class UserCacheService:
|
|||||||
|
|
||||||
# 批量获取用户
|
# 批量获取用户
|
||||||
users = (
|
users = (
|
||||||
await session.exec(
|
await session.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||||
select(User).where(col(User.id).in_(user_ids))
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# 串行缓存用户信息,避免并发数据库访问问题
|
# 串行缓存用户信息,避免并发数据库访问问题
|
||||||
@@ -324,10 +333,7 @@ class UserCacheService:
|
|||||||
logger.error(f"Error caching single user {user.id}: {e}")
|
logger.error(f"Error caching single user {user.id}: {e}")
|
||||||
|
|
||||||
async def refresh_user_cache_on_score_submit(
|
async def refresh_user_cache_on_score_submit(
|
||||||
self,
|
self, session: AsyncSession, user_id: int, mode: GameMode
|
||||||
session: AsyncSession,
|
|
||||||
user_id: int,
|
|
||||||
mode: GameMode
|
|
||||||
):
|
):
|
||||||
"""成绩提交后刷新用户缓存"""
|
"""成绩提交后刷新用户缓存"""
|
||||||
try:
|
try:
|
||||||
@@ -361,10 +367,20 @@ class UserCacheService:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
|
"cached_users": len(
|
||||||
"cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
|
[
|
||||||
|
k
|
||||||
|
for k in user_keys
|
||||||
|
if ":scores:" not in k and ":beatmapsets:" not in k
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"cached_v1_users": len(
|
||||||
|
[k for k in v1_user_keys if ":scores:" not in k]
|
||||||
|
),
|
||||||
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
|
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
|
||||||
"cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
|
"cached_user_beatmapsets": len(
|
||||||
|
[k for k in user_keys if ":beatmapsets:" in k]
|
||||||
|
),
|
||||||
"total_cached_entries": len(all_keys),
|
"total_cached_entries": len(all_keys),
|
||||||
"estimated_total_size_mb": (
|
"estimated_total_size_mb": (
|
||||||
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""测试排行榜缓存序列化修复"""
|
"""测试排行榜缓存序列化修复"""
|
||||||
|
|
||||||
import asyncio
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
import warnings
|
import warnings
|
||||||
from datetime import datetime, UTC
|
|
||||||
from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps
|
from app.service.ranking_cache_service import safe_json_dumps
|
||||||
|
|
||||||
|
|
||||||
def test_datetime_serialization():
|
def test_datetime_serialization():
|
||||||
@@ -16,11 +18,7 @@ def test_datetime_serialization():
|
|||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
"last_updated": datetime.now(UTC),
|
"last_updated": datetime.now(UTC),
|
||||||
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
|
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
|
||||||
"stats": {
|
"stats": {"pp": 1000.0, "accuracy": 95.5, "last_played": datetime.now(UTC)},
|
||||||
"pp": 1000.0,
|
|
||||||
"accuracy": 95.5,
|
|
||||||
"last_played": datetime.now(UTC)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -31,6 +29,7 @@ def test_datetime_serialization():
|
|||||||
|
|
||||||
# 验证可以重新解析
|
# 验证可以重新解析
|
||||||
import json
|
import json
|
||||||
|
|
||||||
parsed = json.loads(json_result)
|
parsed = json.loads(json_result)
|
||||||
assert "last_updated" in parsed
|
assert "last_updated" in parsed
|
||||||
assert isinstance(parsed["last_updated"], str)
|
assert isinstance(parsed["last_updated"], str)
|
||||||
@@ -39,6 +38,7 @@ def test_datetime_serialization():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ datetime 序列化测试失败: {e}")
|
print(f"❌ datetime 序列化测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@@ -48,14 +48,14 @@ def test_boolean_serialization():
|
|||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
"user": {
|
"user": {
|
||||||
"is_active": 1, # 数据库中的整数布尔值
|
"is_active": 1, # 数据库中的整数布尔值
|
||||||
"is_supporter": 0, # 数据库中的整数布尔值
|
"is_supporter": 0, # 数据库中的整数布尔值
|
||||||
"has_profile": True, # 正常布尔值
|
"has_profile": True, # 正常布尔值
|
||||||
},
|
},
|
||||||
"stats": {
|
"stats": {
|
||||||
"is_ranked": 1, # 数据库中的整数布尔值
|
"is_ranked": 1, # 数据库中的整数布尔值
|
||||||
"verified": False, # 正常布尔值
|
"verified": False, # 正常布尔值
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -64,7 +64,11 @@ def test_boolean_serialization():
|
|||||||
json_result = safe_json_dumps(test_data)
|
json_result = safe_json_dumps(test_data)
|
||||||
|
|
||||||
# 检查是否有 Pydantic 序列化警告
|
# 检查是否有 Pydantic 序列化警告
|
||||||
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
|
pydantic_warnings = [
|
||||||
|
warning
|
||||||
|
for warning in w
|
||||||
|
if "PydanticSerializationUnexpectedValue" in str(warning.message)
|
||||||
|
]
|
||||||
if pydantic_warnings:
|
if pydantic_warnings:
|
||||||
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告")
|
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告")
|
||||||
for warning in pydantic_warnings:
|
for warning in pydantic_warnings:
|
||||||
@@ -74,12 +78,14 @@ def test_boolean_serialization():
|
|||||||
|
|
||||||
# 验证序列化结果
|
# 验证序列化结果
|
||||||
import json
|
import json
|
||||||
|
|
||||||
parsed = json.loads(json_result)
|
parsed = json.loads(json_result)
|
||||||
print(f"✅ 布尔值序列化成功,结果: {parsed}")
|
print(f"✅ 布尔值序列化成功,结果: {parsed}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ 布尔值序列化测试失败: {e}")
|
print(f"❌ 布尔值序列化测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@@ -95,8 +101,8 @@ def test_complex_ranking_data():
|
|||||||
"id": 1,
|
"id": 1,
|
||||||
"username": "player1",
|
"username": "player1",
|
||||||
"country_code": "US",
|
"country_code": "US",
|
||||||
"is_active": 1, # 整数布尔值
|
"is_active": 1, # 整数布尔值
|
||||||
"is_supporter": 0, # 整数布尔值
|
"is_supporter": 0, # 整数布尔值
|
||||||
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
|
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
|
||||||
"last_visit": datetime.now(UTC),
|
"last_visit": datetime.now(UTC),
|
||||||
},
|
},
|
||||||
@@ -104,9 +110,9 @@ def test_complex_ranking_data():
|
|||||||
"pp": 8000.0,
|
"pp": 8000.0,
|
||||||
"accuracy": 98.5,
|
"accuracy": 98.5,
|
||||||
"play_count": 5000,
|
"play_count": 5000,
|
||||||
"is_ranked": 1, # 整数布尔值
|
"is_ranked": 1, # 整数布尔值
|
||||||
"last_updated": datetime.now(UTC),
|
"last_updated": datetime.now(UTC),
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2,
|
"id": 2,
|
||||||
@@ -125,8 +131,8 @@ def test_complex_ranking_data():
|
|||||||
"play_count": 4500,
|
"play_count": 4500,
|
||||||
"is_ranked": 1,
|
"is_ranked": 1,
|
||||||
"last_updated": datetime.now(UTC),
|
"last_updated": datetime.now(UTC),
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -134,7 +140,11 @@ def test_complex_ranking_data():
|
|||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
json_result = safe_json_dumps(ranking_data)
|
json_result = safe_json_dumps(ranking_data)
|
||||||
|
|
||||||
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
|
pydantic_warnings = [
|
||||||
|
warning
|
||||||
|
for warning in w
|
||||||
|
if "PydanticSerializationUnexpectedValue" in str(warning.message)
|
||||||
|
]
|
||||||
if pydantic_warnings:
|
if pydantic_warnings:
|
||||||
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
|
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
|
||||||
for warning in pydantic_warnings:
|
for warning in pydantic_warnings:
|
||||||
@@ -144,6 +154,7 @@ def test_complex_ranking_data():
|
|||||||
|
|
||||||
# 验证序列化结果
|
# 验证序列化结果
|
||||||
import json
|
import json
|
||||||
|
|
||||||
parsed = json.loads(json_result)
|
parsed = json.loads(json_result)
|
||||||
assert len(parsed) == 2
|
assert len(parsed) == 2
|
||||||
assert parsed[0]["user"]["username"] == "player1"
|
assert parsed[0]["user"]["username"] == "player1"
|
||||||
@@ -152,6 +163,7 @@ def test_complex_ranking_data():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
|
print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user