This commit is contained in:
咕谷酱
2025-08-22 00:07:19 +08:00
parent bade8658ed
commit 80d4237c5d
22 changed files with 423 additions and 356 deletions

View File

@@ -88,6 +88,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
# 使用线程池执行计算密集型操作以避免阻塞事件循环 # 使用线程池执行计算密集型操作以避免阻塞事件循环
import asyncio import asyncio
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
def _calculate_pp_sync(): def _calculate_pp_sync():
@@ -131,11 +132,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
async def pre_fetch_and_calculate_pp( async def pre_fetch_and_calculate_pp(
score: "Score", score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
beatmap_id: int,
session: AsyncSession,
redis,
fetcher
) -> float: ) -> float:
""" """
优化版PP计算预先获取beatmap文件并使用缓存 优化版PP计算预先获取beatmap文件并使用缓存
@@ -148,9 +145,7 @@ async def pre_fetch_and_calculate_pp(
if settings.suspicious_score_check: if settings.suspicious_score_check:
beatmap_banned = ( beatmap_banned = (
await session.exec( await session.exec(
select(exists()).where( select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
col(BannedBeatmaps.beatmap_id) == beatmap_id
)
) )
).first() ).first()
if beatmap_banned: if beatmap_banned:
@@ -184,10 +179,7 @@ async def pre_fetch_and_calculate_pp(
async def batch_calculate_pp( async def batch_calculate_pp(
scores_data: list[tuple["Score", int]], scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
session: AsyncSession,
redis,
fetcher
) -> list[float]: ) -> list[float]:
""" """
批量计算PP适用于重新计算或批量处理场景 批量计算PP适用于重新计算或批量处理场景

View File

@@ -8,7 +8,7 @@ from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, field_validator, model_validator from pydantic import BaseModel, field_validator, model_validator
from sqlalchemy import Boolean, JSON, Column, DateTime, Text from sqlalchemy import JSON, Boolean, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -205,7 +205,18 @@ class BeatmapsetResp(BeatmapsetBase):
favourite_count: int = 0 favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list) recent_favourites: list[UserResp] = Field(default_factory=list)
@field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before') @field_validator(
"nsfw",
"spotlight",
"video",
"can_be_hyped",
"discussion_locked",
"storyboard",
"discussion_enabled",
"is_scoreable",
"has_favourited",
mode="before",
)
@classmethod @classmethod
def validate_bool_fields(cls, v): def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" """将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""

View File

@@ -2,13 +2,16 @@
数据库字段类型工具 数据库字段类型工具
提供处理数据库和 Pydantic 之间类型转换的工具 提供处理数据库和 Pydantic 之间类型转换的工具
""" """
from typing import Any, Union
from typing import Any
from pydantic import field_validator from pydantic import field_validator
from sqlalchemy import Boolean from sqlalchemy import Boolean
def bool_field_validator(field_name: str): def bool_field_validator(field_name: str):
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数""" """为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
@field_validator(field_name, mode="before") @field_validator(field_name, mode="before")
@classmethod @classmethod
def validate_bool_field(cls, v: Any) -> bool: def validate_bool_field(cls, v: Any) -> bool:
@@ -16,20 +19,21 @@ def bool_field_validator(field_name: str):
if isinstance(v, int): if isinstance(v, int):
return bool(v) return bool(v)
return v return v
return validate_bool_field return validate_bool_field
def create_bool_field(**kwargs): def create_bool_field(**kwargs):
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段""" """创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
from sqlmodel import Field, Column from sqlmodel import Column, Field
# 如果没有指定 sa_column则使用 Boolean 类型 # 如果没有指定 sa_column则使用 Boolean 类型
if 'sa_column' not in kwargs: if "sa_column" not in kwargs:
# 处理 index 参数 # 处理 index 参数
index = kwargs.pop('index', False) index = kwargs.pop("index", False)
if index: if index:
kwargs['sa_column'] = Column(Boolean, index=True) kwargs["sa_column"] = Column(Boolean, index=True)
else: else:
kwargs['sa_column'] = Column(Boolean) kwargs["sa_column"] = Column(Boolean)
return Field(**kwargs) return Field(**kwargs)

View File

@@ -136,7 +136,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_qat: bool = False is_qat: bool = False
is_bng: bool = False is_bng: bool = False
@field_validator('playmode', mode='before') @field_validator("playmode", mode="before")
@classmethod @classmethod
def validate_playmode(cls, v): def validate_playmode(cls, v):
"""将字符串转换为 GameMode 枚举""" """将字符串转换为 GameMode 枚举"""

View File

@@ -100,7 +100,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
sa_column=Column(JSON), default_factory=dict sa_column=Column(JSON), default_factory=dict
) )
@field_validator('maximum_statistics', mode='before') @field_validator("maximum_statistics", mode="before")
@classmethod @classmethod
def validate_maximum_statistics(cls, v): def validate_maximum_statistics(cls, v):
"""处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举""" """处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
@@ -151,7 +151,7 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True) gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True) pinned_order: int = Field(default=0, exclude=True)
@field_validator('gamemode', mode='before') @field_validator("gamemode", mode="before")
@classmethod @classmethod
def validate_gamemode(cls, v): def validate_gamemode(cls, v):
"""将字符串转换为 GameMode 枚举""" """将字符串转换为 GameMode 枚举"""
@@ -209,7 +209,16 @@ class ScoreResp(ScoreBase):
ranked: bool = False ranked: bool = False
current_user_attributes: CurrentUserAttributes | None = None current_user_attributes: CurrentUserAttributes | None = None
@field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before') @field_validator(
"has_replay",
"passed",
"preserve",
"is_perfect_combo",
"legacy_perfect",
"processed",
"ranked",
mode="before",
)
@classmethod @classmethod
def validate_bool_fields(cls, v): def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" """将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
@@ -217,7 +226,7 @@ class ScoreResp(ScoreBase):
return bool(v) return bool(v)
return v return v
@field_validator('statistics', 'maximum_statistics', mode='before') @field_validator("statistics", "maximum_statistics", mode="before")
@classmethod @classmethod
def validate_statistics_fields(cls, v): def validate_statistics_fields(cls, v):
"""处理统计字段中的字符串键,转换为 HitResult 枚举""" """处理统计字段中的字符串键,转换为 HitResult 枚举"""

View File

@@ -44,7 +44,7 @@ class UserStatisticsBase(SQLModel):
replays_watched_by_others: int = Field(default=0) replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=True) is_ranked: bool = Field(default=True)
@field_validator('mode', mode='before') @field_validator("mode", mode="before")
@classmethod @classmethod
def validate_mode(cls, v): def validate_mode(cls, v):
"""将字符串转换为 GameMode 枚举""" """将字符串转换为 GameMode 枚举"""

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import time import time
from typing import Optional
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.log import logger from app.log import logger
@@ -11,6 +10,7 @@ from httpx import AsyncClient
class TokenAuthError(Exception): class TokenAuthError(Exception):
"""Token 授权失败异常""" """Token 授权失败异常"""
pass pass
@@ -55,7 +55,7 @@ class BaseFetcher:
return await self._request_with_retry(url, method, **kwargs) return await self._request_with_retry(url, method, **kwargs)
async def _request_with_retry( async def _request_with_retry(
self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs self, url: str, method: str = "GET", max_retries: int | None = None, **kwargs
) -> dict: ) -> dict:
""" """
带重试机制的请求方法 带重试机制的请求方法
@@ -126,7 +126,9 @@ class BaseFetcher:
) )
continue continue
else: else:
logger.error(f"Request failed after {max_retries + 1} attempts: {e}") logger.error(
f"Request failed after {max_retries + 1} attempts: {e}"
)
break break
# 如果所有重试都失败了 # 如果所有重试都失败了
@@ -194,9 +196,13 @@ class BaseFetcher:
f"fetcher:refresh_token:{self.client_id}", f"fetcher:refresh_token:{self.client_id}",
self.refresh_token, self.refresh_token,
) )
logger.info(f"Successfully refreshed access token for client {self.client_id}") logger.info(
f"Successfully refreshed access token for client {self.client_id}"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}") logger.error(
f"Failed to refresh access token for client {self.client_id}: {e}"
)
# 清除无效的 token要求重新授权 # 清除无效的 token要求重新授权
self.access_token = "" self.access_token = ""
self.refresh_token = "" self.refresh_token = ""
@@ -204,7 +210,9 @@ class BaseFetcher:
redis = get_redis() redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}") await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}") await redis.delete(f"fetcher:refresh_token:{self.client_id}")
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}") logger.warning(
f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
)
raise raise
async def _trigger_reauthorization(self) -> None: async def _trigger_reauthorization(self) -> None:

View File

@@ -101,6 +101,7 @@ class BeatmapsetFetcher(BaseFetcher):
return json.loads(cursor_json) return json.loads(cursor_json)
except Exception: except Exception:
return {} return {}
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp: async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>" f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
@@ -164,9 +165,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 将结果缓存 15 分钟 # 将结果缓存 15 分钟
cache_ttl = 15 * 60 # 15 分钟 cache_ttl = 15 * 60 # 15 分钟
await redis_client.set( await redis_client.set(
cache_key, cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
json.dumps(api_response, separators=(",", ":")),
ex=cache_ttl
) )
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
@@ -178,10 +177,12 @@ class BeatmapsetFetcher(BaseFetcher):
# 智能预取只在用户明确搜索时才预取避免过多API请求 # 智能预取只在用户明确搜索时才预取避免过多API请求
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取 # 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
if (api_response.get("cursor") and if api_response.get("cursor") and (
(query.q or query.s != "leaderboard" or cursor)): query.q or query.s != "leaderboard" or cursor
):
# 在后台预取下1页减少预取量 # 在后台预取下1页减少预取量
import asyncio import asyncio
# 不立即创建任务,而是延迟一段时间再预取 # 不立即创建任务,而是延迟一段时间再预取
async def delayed_prefetch(): async def delayed_prefetch():
await asyncio.sleep(3.0) # 延迟3秒 await asyncio.sleep(3.0) # 延迟3秒
@@ -200,8 +201,11 @@ class BeatmapsetFetcher(BaseFetcher):
return resp return resp
async def prefetch_next_pages( async def prefetch_next_pages(
self, query: SearchQueryModel, current_cursor: Cursor, self,
redis_client: redis.Redis, pages: int = 3 query: SearchQueryModel,
current_cursor: Cursor,
redis_client: redis.Redis,
pages: int = 3,
) -> None: ) -> None:
"""预取下几页内容""" """预取下几页内容"""
if not current_cursor: if not current_cursor:
@@ -269,7 +273,7 @@ class BeatmapsetFetcher(BaseFetcher):
await redis_client.set( await redis_client.set(
next_cache_key, next_cache_key,
json.dumps(api_response, separators=(",", ":")), json.dumps(api_response, separators=(",", ":")),
ex=prefetch_ttl ex=prefetch_ttl,
) )
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
@@ -317,7 +321,6 @@ class BeatmapsetFetcher(BaseFetcher):
params=params, params=params,
) )
if api_response.get("cursor"): if api_response.get("cursor"):
cursor_dict = api_response["cursor"] cursor_dict = api_response["cursor"]
api_response["cursor_string"] = self._encode_cursor(cursor_dict) api_response["cursor_string"] = self._encode_cursor(cursor_dict)
@@ -327,7 +330,7 @@ class BeatmapsetFetcher(BaseFetcher):
await redis_client.set( await redis_client.set(
cache_key, cache_key,
json.dumps(api_response, separators=(",", ":")), json.dumps(api_response, separators=(",", ":")),
ex=cache_ttl ex=cache_ttl,
) )
logger.opt(colors=True).info( logger.opt(colors=True).info(
@@ -335,7 +338,6 @@ class BeatmapsetFetcher(BaseFetcher):
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)" f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
) )
if api_response.get("cursor"): if api_response.get("cursor"):
await self.prefetch_next_pages( await self.prefetch_next_pages(
query, api_response["cursor"], redis_client, pages=2 query, api_response["cursor"], redis_client, pages=2

View File

@@ -5,6 +5,7 @@ Rate limiter for osu! API requests to avoid abuse detection.
- 突发:短时间内最多 200 次额外请求 - 突发:短时间内最多 200 次额外请求
- 建议:每分钟不超过 60 次请求以避免滥用检测 - 建议:每分钟不超过 60 次请求以避免滥用检测
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio

View File

@@ -4,7 +4,6 @@ import asyncio
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Literal
from app.config import settings
from app.database.lazer_user import User from app.database.lazer_user import User
from app.database.statistics import UserStatistics, UserStatisticsResp from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
@@ -132,7 +131,9 @@ async def get_user(
if is_id_query: if is_id_query:
try: try:
user_id_for_cache = int(user) user_id_for_cache = int(user)
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset) cached_v1_user = await cache_service.get_v1_user_from_cache(
user_id_for_cache, ruleset
)
if cached_v1_user: if cached_v1_user:
return [V1User(**cached_v1_user)] return [V1User(**cached_v1_user)]
except (ValueError, TypeError): except (ValueError, TypeError):

View File

@@ -2,15 +2,15 @@
缓存管理和监控接口 缓存管理和监控接口
提供缓存统计、清理和预热功能 提供缓存统计、清理和预热功能
""" """
from __future__ import annotations from __future__ import annotations
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.dependencies.user import get_current_user
from app.service.user_cache_service import get_user_cache_service from app.service.user_cache_service import get_user_cache_service
from .router import router from .router import router
from fastapi import Depends, HTTPException, Security from fastapi import Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis from redis.asyncio import Redis
@@ -54,13 +54,10 @@ async def get_cache_stats(
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0 hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2) redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
return CacheStatsResponse( return CacheStatsResponse(user_cache=user_cache_stats, redis_info=redis_stats)
user_cache=user_cache_stats,
redis_info=redis_stats
)
except Exception as e: except Exception as e:
raise HTTPException(500, f"Failed to get cache stats: {str(e)}") raise HTTPException(500, f"Failed to get cache stats: {e!s}")
@router.post( @router.post(
@@ -80,7 +77,7 @@ async def invalidate_user_cache(
await cache_service.invalidate_v1_user_cache(user_id) await cache_service.invalidate_v1_user_cache(user_id)
return {"message": f"Cache invalidated for user {user_id}"} return {"message": f"Cache invalidated for user {user_id}"}
except Exception as e: except Exception as e:
raise HTTPException(500, f"Failed to invalidate cache: {str(e)}") raise HTTPException(500, f"Failed to invalidate cache: {e!s}")
@router.post( @router.post(
@@ -106,7 +103,7 @@ async def clear_all_user_cache(
return {"message": "No cache entries found"} return {"message": "No cache entries found"}
except Exception as e: except Exception as e:
raise HTTPException(500, f"Failed to clear cache: {str(e)}") raise HTTPException(500, f"Failed to clear cache: {e!s}")
class CacheWarmupRequest(BaseModel): class CacheWarmupRequest(BaseModel):
@@ -131,14 +128,18 @@ async def warmup_cache(
if request.user_ids: if request.user_ids:
# 预热指定用户 # 预热指定用户
from app.dependencies.database import with_db from app.dependencies.database import with_db
async with with_db() as session: async with with_db() as session:
await cache_service.preload_user_cache(session, request.user_ids) await cache_service.preload_user_cache(session, request.user_ids)
return {"message": f"Warmed up cache for {len(request.user_ids)} users"} return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
else: else:
# 预热活跃用户 # 预热活跃用户
from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task from app.scheduler.user_cache_scheduler import (
schedule_user_cache_preload_task,
)
await schedule_user_cache_preload_task() await schedule_user_cache_preload_task()
return {"message": f"Warmed up cache for top {request.limit} active users"} return {"message": f"Warmed up cache for top {request.limit} active users"}
except Exception as e: except Exception as e:
raise HTTPException(500, f"Failed to warmup cache: {str(e)}") raise HTTPException(500, f"Failed to warmup cache: {e!s}")

View File

@@ -111,6 +111,7 @@ async def get_country_ranking(
# 创建后台任务来缓存数据 # 创建后台任务来缓存数据
import asyncio import asyncio
asyncio.create_task(cache_task) asyncio.create_task(cache_task)
# 返回当前页的结果 # 返回当前页的结果
@@ -144,16 +145,12 @@ async def get_user_ranking(
cache_service = get_ranking_cache_service(redis) cache_service = get_ranking_cache_service(redis)
# 尝试从缓存获取数据 # 尝试从缓存获取数据
cached_data = await cache_service.get_cached_ranking( cached_data = await cache_service.get_cached_ranking(ruleset, type, country, page)
ruleset, type, country, page
)
if cached_data: if cached_data:
# 从缓存返回数据 # 从缓存返回数据
return TopUsersResponse( return TopUsersResponse(
ranking=[ ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
UserStatisticsResp.model_validate(item) for item in cached_data
]
) )
# 缓存未命中,从数据库查询 # 缓存未命中,从数据库查询
@@ -191,11 +188,17 @@ async def get_user_ranking(
# 使用配置文件中的TTL设置 # 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data] cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking( cache_task = cache_service.cache_ranking(
ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60 ruleset,
type,
cache_data,
country,
page,
ttl=settings.ranking_cache_expire_minutes * 60,
) )
# 创建后台任务来缓存数据 # 创建后台任务来缓存数据
import asyncio import asyncio
asyncio.create_task(cache_task) asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data) resp = TopUsersResponse(ranking=ranking_data)

View File

@@ -73,7 +73,9 @@ async def get_users(
# 查询未缓存的用户 # 查询未缓存的用户
if uncached_user_ids: if uncached_user_ids:
searched_users = ( searched_users = (
await session.exec(select(User).where(col(User.id).in_(uncached_user_ids))) await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids))
)
).all() ).all()
# 将查询到的用户添加到缓存并返回 # 将查询到的用户添加到缓存并返回
@@ -275,9 +277,13 @@ async def get_user_beatmapsets(
# 异步缓存结果 # 异步缓存结果
async def cache_beatmapsets(): async def cache_beatmapsets():
try: try:
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset) await cache_service.cache_user_beatmapsets(
user_id, type.value, resp, limit, offset
)
except Exception as e: except Exception as e:
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}") logger.error(
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
)
asyncio.create_task(cache_beatmapsets()) asyncio.create_task(cache_beatmapsets())

View File

@@ -1,4 +1,5 @@
"""缓存调度器模块""" """缓存调度器模块"""
from __future__ import annotations from __future__ import annotations
from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from app.config import settings from app.config import settings
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.log import logger from app.log import logger
from app.scheduler.user_cache_scheduler import ( from app.scheduler.user_cache_scheduler import (
@@ -59,7 +59,9 @@ class CacheScheduler:
# 从配置文件获取间隔设置 # 从配置文件获取间隔设置
check_interval = 5 * 60 # 5分钟检查间隔 check_interval = 5 * 60 # 5分钟检查间隔
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔 beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取 ranking_cache_interval = (
settings.ranking_cache_refresh_interval_minutes * 60
) # 从配置读取
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔 user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔 user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔
@@ -131,13 +133,12 @@ class CacheScheduler:
redis = get_redis() redis = get_redis()
# 导入排行榜缓存服务 # 导入排行榜缓存服务
# 使用独立的数据库会话
from app.dependencies.database import with_db
from app.service.ranking_cache_service import ( from app.service.ranking_cache_service import (
get_ranking_cache_service,
schedule_ranking_refresh_task, schedule_ranking_refresh_task,
) )
# 使用独立的数据库会话
from app.dependencies.database import with_db
async with with_db() as session: async with with_db() as session:
await schedule_ranking_refresh_task(session, redis) await schedule_ranking_refresh_task(session, redis)
@@ -171,6 +172,7 @@ class CacheScheduler:
# Beatmap缓存调度器保持向后兼容 # Beatmap缓存调度器保持向后兼容
class BeatmapsetCacheScheduler(CacheScheduler): class BeatmapsetCacheScheduler(CacheScheduler):
"""谱面集缓存调度器 - 为了向后兼容""" """谱面集缓存调度器 - 为了向后兼容"""
pass pass

View File

@@ -1,15 +1,15 @@
""" """
用户缓存预热任务调度器 用户缓存预热任务调度器
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from app.config import settings from app.config import settings
from app.database import User
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_redis
from app.log import logger from app.log import logger
from app.service.user_cache_service import get_user_cache_service from app.service.user_cache_service import get_user_cache_service
@@ -31,6 +31,7 @@ async def schedule_user_cache_preload_task():
# 使用独立的数据库会话 # 使用独立的数据库会话
from app.dependencies.database import with_db from app.dependencies.database import with_db
async with with_db() as session: async with with_db() as session:
# 获取最近24小时内活跃的用户提交过成绩的用户 # 获取最近24小时内活跃的用户提交过成绩的用户
recent_time = datetime.now(UTC) - timedelta(hours=24) recent_time = datetime.now(UTC) - timedelta(hours=24)
@@ -68,6 +69,7 @@ async def schedule_user_cache_warmup_task():
# 使用独立的数据库会话 # 使用独立的数据库会话
from app.dependencies.database import with_db from app.dependencies.database import with_db
async with with_db() as session: async with with_db() as session:
# 获取全球排行榜前100的用户 # 获取全球排行榜前100的用户
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics

View File

@@ -2,6 +2,7 @@
Beatmap缓存预取服务 Beatmap缓存预取服务
用于提前缓存热门beatmap减少成绩计算时的获取延迟 用于提前缓存热门beatmap减少成绩计算时的获取延迟
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@@ -155,9 +156,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
async def schedule_preload_task( async def schedule_preload_task(
session: AsyncSession, session: AsyncSession, redis: Redis, fetcher: "Fetcher"
redis: Redis,
fetcher: "Fetcher"
): ):
""" """
定时预加载任务 定时预加载任务

View File

@@ -2,11 +2,12 @@
用户排行榜缓存服务 用户排行榜缓存服务
用于缓存用户排行榜数据,减轻数据库压力 用于缓存用户排行榜数据,减轻数据库压力
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from datetime import UTC, datetime
import json import json
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
from app.config import settings from app.config import settings
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
def safe_json_dumps(data) -> str: def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象""" """安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")) return json.dumps(
data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
)
class RankingCacheService: class RankingCacheService:
@@ -107,11 +110,7 @@ class RankingCacheService:
# 使用配置文件的TTL设置 # 使用配置文件的TTL设置
if ttl is None: if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set( await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
cache_key,
safe_json_dumps(ranking_data),
ex=ttl
)
logger.debug(f"Cached ranking data for {cache_key}") logger.debug(f"Cached ranking data for {cache_key}")
except Exception as e: except Exception as e:
logger.error(f"Error caching ranking: {e}") logger.error(f"Error caching ranking: {e}")
@@ -148,11 +147,7 @@ class RankingCacheService:
# 使用配置文件的TTL设置统计信息缓存时间更长 # 使用配置文件的TTL设置统计信息缓存时间更长
if ttl is None: if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间 ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set( await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
cache_key,
safe_json_dumps(stats),
ex=ttl
)
logger.debug(f"Cached stats for {cache_key}") logger.debug(f"Cached stats for {cache_key}")
except Exception as e: except Exception as e:
logger.error(f"Error caching stats: {e}") logger.error(f"Error caching stats: {e}")
@@ -186,11 +181,7 @@ class RankingCacheService:
cache_key = self._get_country_cache_key(ruleset, page) cache_key = self._get_country_cache_key(ruleset, page)
if ttl is None: if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set( await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
cache_key,
safe_json_dumps(ranking_data),
ex=ttl
)
logger.debug(f"Cached country ranking data for {cache_key}") logger.debug(f"Cached country ranking data for {cache_key}")
except Exception as e: except Exception as e:
logger.error(f"Error caching country ranking: {e}") logger.error(f"Error caching country ranking: {e}")
@@ -219,11 +210,7 @@ class RankingCacheService:
cache_key = self._get_country_stats_cache_key(ruleset) cache_key = self._get_country_stats_cache_key(ruleset)
if ttl is None: if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间 ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set( await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
cache_key,
safe_json_dumps(stats),
ex=ttl
)
logger.debug(f"Cached country stats for {cache_key}") logger.debug(f"Cached country stats for {cache_key}")
except Exception as e: except Exception as e:
logger.error(f"Error caching country stats: {e}") logger.error(f"Error caching country stats: {e}")
@@ -238,7 +225,9 @@ class RankingCacheService:
) -> None: ) -> None:
"""刷新排行榜缓存""" """刷新排行榜缓存"""
if self._refreshing: if self._refreshing:
logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}") logger.info(
f"Ranking cache refresh already in progress for {ruleset}:{type}"
)
return return
# 使用配置文件的设置 # 使用配置文件的设置
@@ -264,7 +253,9 @@ class RankingCacheService:
order_by = col(UserStatistics.ranked_score).desc() order_by = col(UserStatistics.ranked_score).desc()
if country: if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper())) wheres.append(
col(UserStatistics.user).has(country_code=country.upper())
)
# 获取总用户数用于统计 # 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres) total_users_query = select(UserStatistics).where(*wheres)
@@ -308,9 +299,7 @@ class RankingCacheService:
ranking_data.append(user_dict) ranking_data.append(user_dict)
# 缓存这一页的数据 # 缓存这一页的数据
await self.cache_ranking( await self.cache_ranking(ruleset, type, ranking_data, country, page)
ruleset, type, ranking_data, country, page
)
# 添加延迟避免数据库过载 # 添加延迟避免数据库过载
if page < max_pages: if page < max_pages:
@@ -334,7 +323,9 @@ class RankingCacheService:
) -> None: ) -> None:
"""刷新地区排行榜缓存""" """刷新地区排行榜缓存"""
if self._refreshing: if self._refreshing:
logger.info(f"Country ranking cache refresh already in progress for {ruleset}") logger.info(
f"Country ranking cache refresh already in progress for {ruleset}"
)
return return
if max_pages is None: if max_pages is None:
@@ -346,6 +337,7 @@ class RankingCacheService:
# 获取所有国家 # 获取所有国家
from app.database import User from app.database import User
countries = (await session.exec(select(User.country_code).distinct())).all() countries = (await session.exec(select(User.country_code).distinct())).all()
# 计算每个国家的统计数据 # 计算每个国家的统计数据
@@ -430,6 +422,7 @@ class RankingCacheService:
# 获取需要缓存的国家列表活跃用户数量前20的国家 # 获取需要缓存的国家列表活跃用户数量前20的国家
from app.database import User from app.database import User
from sqlmodel import func from sqlmodel import func
countries_query = ( countries_query = (
@@ -456,7 +449,9 @@ class RankingCacheService:
for country in top_countries: for country in top_countries:
for mode in game_modes: for mode in game_modes:
for ranking_type in ranking_types: for ranking_type in ranking_types:
task = self.refresh_ranking_cache(session, mode, ranking_type, country) task = self.refresh_ranking_cache(
session, mode, ranking_type, country
)
refresh_tasks.append(task) refresh_tasks.append(task)
# 地区排行榜 # 地区排行榜
@@ -498,12 +493,14 @@ class RankingCacheService:
if keys: if keys:
await self.redis.delete(*keys) await self.redis.delete(*keys)
deleted_keys += len(keys) deleted_keys += len(keys)
logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}") logger.info(
f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
)
elif ruleset: elif ruleset:
# 删除特定游戏模式的所有缓存 # 删除特定游戏模式的所有缓存
patterns = [ patterns = [
f"ranking:{ruleset}:*", f"ranking:{ruleset}:*",
f"country_ranking:{ruleset}:*" if include_country_ranking else None f"country_ranking:{ruleset}:*" if include_country_ranking else None,
] ]
for pattern in patterns: for pattern in patterns:
if pattern: if pattern:

View File

@@ -2,24 +2,23 @@
用户缓存服务 用户缓存服务
用于缓存用户信息,提供热缓存和实时刷新功能 用于缓存用户信息,提供热缓存和实时刷新功能
""" """
from __future__ import annotations from __future__ import annotations
import asyncio from datetime import datetime
import json import json
from datetime import UTC, datetime, timedelta from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from app.config import settings from app.config import settings
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database import User, UserResp from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore from app.database.score import ScoreResp
from app.database.score import Score, ScoreResp
from app.log import logger from app.log import logger
from app.models.score import GameMode from app.models.score import GameMode
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, exists, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -48,16 +47,16 @@ class UserCacheService:
self._refreshing = False self._refreshing = False
self._background_tasks: set = set() self._background_tasks: set = set()
def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str: def _get_v1_user_cache_key(
self, user_id: int, ruleset: GameMode | None = None
) -> str:
"""生成 V1 用户缓存键""" """生成 V1 用户缓存键"""
if ruleset: if ruleset:
return f"v1_user:{user_id}:ruleset:{ruleset}" return f"v1_user:{user_id}:ruleset:{ruleset}"
return f"v1_user:{user_id}" return f"v1_user:{user_id}"
async def get_v1_user_from_cache( async def get_v1_user_from_cache(
self, self, user_id: int, ruleset: GameMode | None = None
user_id: int,
ruleset: GameMode | None = None
) -> dict | None: ) -> dict | None:
"""从缓存获取 V1 用户信息""" """从缓存获取 V1 用户信息"""
try: try:
@@ -76,7 +75,7 @@ class UserCacheService:
user_data: dict, user_data: dict,
user_id: int, user_id: int,
ruleset: GameMode | None = None, ruleset: GameMode | None = None,
expire_seconds: int | None = None expire_seconds: int | None = None,
): ):
"""缓存 V1 用户信息""" """缓存 V1 用户信息"""
try: try:
@@ -97,7 +96,9 @@ class UserCacheService:
keys = await self.redis.keys(pattern) keys = await self.redis.keys(pattern)
if keys: if keys:
await self.redis.delete(*keys) await self.redis.delete(*keys)
logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}") logger.info(
f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
)
except Exception as e: except Exception as e:
logger.error(f"Error invalidating V1 user cache: {e}") logger.error(f"Error invalidating V1 user cache: {e}")
@@ -113,26 +114,20 @@ class UserCacheService:
score_type: str, score_type: str,
mode: GameMode | None = None, mode: GameMode | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0 offset: int = 0,
) -> str: ) -> str:
"""生成用户成绩缓存键""" """生成用户成绩缓存键"""
mode_part = f":{mode}" if mode else "" mode_part = f":{mode}" if mode else ""
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}" return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
def _get_user_beatmapsets_cache_key( def _get_user_beatmapsets_cache_key(
self, self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
user_id: int,
beatmapset_type: str,
limit: int = 100,
offset: int = 0
) -> str: ) -> str:
"""生成用户谱面集缓存键""" """生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}" return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache( async def get_user_from_cache(
self, self, user_id: int, ruleset: GameMode | None = None
user_id: int,
ruleset: GameMode | None = None
) -> UserResp | None: ) -> UserResp | None:
"""从缓存获取用户信息""" """从缓存获取用户信息"""
try: try:
@@ -151,7 +146,7 @@ class UserCacheService:
self, self,
user_resp: UserResp, user_resp: UserResp,
ruleset: GameMode | None = None, ruleset: GameMode | None = None,
expire_seconds: int | None = None expire_seconds: int | None = None,
): ):
"""缓存用户信息""" """缓存用户信息"""
try: try:
@@ -173,14 +168,18 @@ class UserCacheService:
score_type: str, score_type: str,
mode: GameMode | None = None, mode: GameMode | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0 offset: int = 0,
) -> list[ScoreResp] | None: ) -> list[ScoreResp] | None:
"""从缓存获取用户成绩""" """从缓存获取用户成绩"""
try: try:
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset) cache_key = self._get_user_scores_cache_key(
user_id, score_type, mode, limit, offset
)
cached_data = await self.redis.get(cache_key) cached_data = await self.redis.get(cache_key)
if cached_data: if cached_data:
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}") logger.debug(
f"User scores cache hit for user {user_id}, type {score_type}"
)
data = json.loads(cached_data) data = json.loads(cached_data)
return [ScoreResp(**score_data) for score_data in data] return [ScoreResp(**score_data) for score_data in data]
return None return None
@@ -196,34 +195,38 @@ class UserCacheService:
mode: GameMode | None = None, mode: GameMode | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
expire_seconds: int | None = None expire_seconds: int | None = None,
): ):
"""缓存用户成绩""" """缓存用户成绩"""
try: try:
if expire_seconds is None: if expire_seconds is None:
expire_seconds = settings.user_scores_cache_expire_seconds expire_seconds = settings.user_scores_cache_expire_seconds
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset) cache_key = self._get_user_scores_cache_key(
user_id, score_type, mode, limit, offset
)
# 使用 model_dump_json() 而不是 model_dump() + json.dumps() # 使用 model_dump_json() 而不是 model_dump() + json.dumps()
scores_json_list = [score.model_dump_json() for score in scores] scores_json_list = [score.model_dump_json() for score in scores]
cached_data = f"[{','.join(scores_json_list)}]" cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data) await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s") logger.debug(
f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
)
except Exception as e: except Exception as e:
logger.error(f"Error caching user scores: {e}") logger.error(f"Error caching user scores: {e}")
async def get_user_beatmapsets_from_cache( async def get_user_beatmapsets_from_cache(
self, self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
user_id: int,
beatmapset_type: str,
limit: int = 100,
offset: int = 0
) -> list[Any] | None: ) -> list[Any] | None:
"""从缓存获取用户谱面集""" """从缓存获取用户谱面集"""
try: try:
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset) cache_key = self._get_user_beatmapsets_cache_key(
user_id, beatmapset_type, limit, offset
)
cached_data = await self.redis.get(cache_key) cached_data = await self.redis.get(cache_key)
if cached_data: if cached_data:
logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}") logger.debug(
f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
)
return json.loads(cached_data) return json.loads(cached_data)
return None return None
except Exception as e: except Exception as e:
@@ -237,23 +240,27 @@ class UserCacheService:
beatmapsets: list[Any], beatmapsets: list[Any],
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
expire_seconds: int | None = None expire_seconds: int | None = None,
): ):
"""缓存用户谱面集""" """缓存用户谱面集"""
try: try:
if expire_seconds is None: if expire_seconds is None:
expire_seconds = settings.user_beatmapsets_cache_expire_seconds expire_seconds = settings.user_beatmapsets_cache_expire_seconds
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset) cache_key = self._get_user_beatmapsets_cache_key(
user_id, beatmapset_type, limit, offset
)
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps # 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
serialized_beatmapsets = [] serialized_beatmapsets = []
for bms in beatmapsets: for bms in beatmapsets:
if hasattr(bms, 'model_dump_json'): if hasattr(bms, "model_dump_json"):
serialized_beatmapsets.append(bms.model_dump_json()) serialized_beatmapsets.append(bms.model_dump_json())
else: else:
serialized_beatmapsets.append(safe_json_dumps(bms)) serialized_beatmapsets.append(safe_json_dumps(bms))
cached_data = f"[{','.join(serialized_beatmapsets)}]" cached_data = f"[{','.join(serialized_beatmapsets)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data) await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s") logger.debug(
f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
)
except Exception as e: except Exception as e:
logger.error(f"Error caching user beatmapsets: {e}") logger.error(f"Error caching user beatmapsets: {e}")
@@ -269,7 +276,9 @@ class UserCacheService:
except Exception as e: except Exception as e:
logger.error(f"Error invalidating user cache: {e}") logger.error(f"Error invalidating user cache: {e}")
async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None): async def invalidate_user_scores_cache(
self, user_id: int, mode: GameMode | None = None
):
"""使用户成绩缓存失效""" """使用户成绩缓存失效"""
try: try:
# 删除用户成绩相关缓存 # 删除用户成绩相关缓存
@@ -278,7 +287,9 @@ class UserCacheService:
keys = await self.redis.keys(pattern) keys = await self.redis.keys(pattern)
if keys: if keys:
await self.redis.delete(*keys) await self.redis.delete(*keys)
logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}") logger.info(
f"Invalidated {len(keys)} score cache entries for user {user_id}"
)
except Exception as e: except Exception as e:
logger.error(f"Error invalidating user scores cache: {e}") logger.error(f"Error invalidating user scores cache: {e}")
@@ -293,9 +304,7 @@ class UserCacheService:
# 批量获取用户 # 批量获取用户
users = ( users = (
await session.exec( await session.exec(select(User).where(col(User.id).in_(user_ids)))
select(User).where(col(User.id).in_(user_ids))
)
).all() ).all()
# 串行缓存用户信息,避免并发数据库访问问题 # 串行缓存用户信息,避免并发数据库访问问题
@@ -324,10 +333,7 @@ class UserCacheService:
logger.error(f"Error caching single user {user.id}: {e}") logger.error(f"Error caching single user {user.id}: {e}")
async def refresh_user_cache_on_score_submit( async def refresh_user_cache_on_score_submit(
self, self, session: AsyncSession, user_id: int, mode: GameMode
session: AsyncSession,
user_id: int,
mode: GameMode
): ):
"""成绩提交后刷新用户缓存""" """成绩提交后刷新用户缓存"""
try: try:
@@ -361,10 +367,20 @@ class UserCacheService:
continue continue
return { return {
"cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]), "cached_users": len(
"cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]), [
k
for k in user_keys
if ":scores:" not in k and ":beatmapsets:" not in k
]
),
"cached_v1_users": len(
[k for k in v1_user_keys if ":scores:" not in k]
),
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]), "cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
"cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]), "cached_user_beatmapsets": len(
[k for k in user_keys if ":beatmapsets:" in k]
),
"total_cached_entries": len(all_keys), "total_cached_entries": len(all_keys),
"estimated_total_size_mb": ( "estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 round(total_size / 1024 / 1024, 2) if total_size > 0 else 0

View File

@@ -1,10 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""测试排行榜缓存序列化修复""" """测试排行榜缓存序列化修复"""
import asyncio from __future__ import annotations
from datetime import UTC, datetime
import warnings import warnings
from datetime import datetime, UTC
from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps from app.service.ranking_cache_service import safe_json_dumps
def test_datetime_serialization(): def test_datetime_serialization():
@@ -16,11 +18,7 @@ def test_datetime_serialization():
"username": "test_user", "username": "test_user",
"last_updated": datetime.now(UTC), "last_updated": datetime.now(UTC),
"join_date": datetime(2020, 1, 1, tzinfo=UTC), "join_date": datetime(2020, 1, 1, tzinfo=UTC),
"stats": { "stats": {"pp": 1000.0, "accuracy": 95.5, "last_played": datetime.now(UTC)},
"pp": 1000.0,
"accuracy": 95.5,
"last_played": datetime.now(UTC)
}
} }
try: try:
@@ -31,6 +29,7 @@ def test_datetime_serialization():
# 验证可以重新解析 # 验证可以重新解析
import json import json
parsed = json.loads(json_result) parsed = json.loads(json_result)
assert "last_updated" in parsed assert "last_updated" in parsed
assert isinstance(parsed["last_updated"], str) assert isinstance(parsed["last_updated"], str)
@@ -39,6 +38,7 @@ def test_datetime_serialization():
except Exception as e: except Exception as e:
print(f"❌ datetime 序列化测试失败: {e}") print(f"❌ datetime 序列化测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@@ -48,14 +48,14 @@ def test_boolean_serialization():
test_data = { test_data = {
"user": { "user": {
"is_active": 1, # 数据库中的整数布尔值 "is_active": 1, # 数据库中的整数布尔值
"is_supporter": 0, # 数据库中的整数布尔值 "is_supporter": 0, # 数据库中的整数布尔值
"has_profile": True, # 正常布尔值 "has_profile": True, # 正常布尔值
}, },
"stats": { "stats": {
"is_ranked": 1, # 数据库中的整数布尔值 "is_ranked": 1, # 数据库中的整数布尔值
"verified": False, # 正常布尔值 "verified": False, # 正常布尔值
} },
} }
try: try:
@@ -64,7 +64,11 @@ def test_boolean_serialization():
json_result = safe_json_dumps(test_data) json_result = safe_json_dumps(test_data)
# 检查是否有 Pydantic 序列化警告 # 检查是否有 Pydantic 序列化警告
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)] pydantic_warnings = [
warning
for warning in w
if "PydanticSerializationUnexpectedValue" in str(warning.message)
]
if pydantic_warnings: if pydantic_warnings:
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告") print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告")
for warning in pydantic_warnings: for warning in pydantic_warnings:
@@ -74,12 +78,14 @@ def test_boolean_serialization():
# 验证序列化结果 # 验证序列化结果
import json import json
parsed = json.loads(json_result) parsed = json.loads(json_result)
print(f"✅ 布尔值序列化成功,结果: {parsed}") print(f"✅ 布尔值序列化成功,结果: {parsed}")
except Exception as e: except Exception as e:
print(f"❌ 布尔值序列化测试失败: {e}") print(f"❌ 布尔值序列化测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@@ -95,8 +101,8 @@ def test_complex_ranking_data():
"id": 1, "id": 1,
"username": "player1", "username": "player1",
"country_code": "US", "country_code": "US",
"is_active": 1, # 整数布尔值 "is_active": 1, # 整数布尔值
"is_supporter": 0, # 整数布尔值 "is_supporter": 0, # 整数布尔值
"join_date": datetime(2020, 1, 1, tzinfo=UTC), "join_date": datetime(2020, 1, 1, tzinfo=UTC),
"last_visit": datetime.now(UTC), "last_visit": datetime.now(UTC),
}, },
@@ -104,9 +110,9 @@ def test_complex_ranking_data():
"pp": 8000.0, "pp": 8000.0,
"accuracy": 98.5, "accuracy": 98.5,
"play_count": 5000, "play_count": 5000,
"is_ranked": 1, # 整数布尔值 "is_ranked": 1, # 整数布尔值
"last_updated": datetime.now(UTC), "last_updated": datetime.now(UTC),
} },
}, },
{ {
"id": 2, "id": 2,
@@ -125,8 +131,8 @@ def test_complex_ranking_data():
"play_count": 4500, "play_count": 4500,
"is_ranked": 1, "is_ranked": 1,
"last_updated": datetime.now(UTC), "last_updated": datetime.now(UTC),
} },
} },
] ]
try: try:
@@ -134,7 +140,11 @@ def test_complex_ranking_data():
warnings.simplefilter("always") warnings.simplefilter("always")
json_result = safe_json_dumps(ranking_data) json_result = safe_json_dumps(ranking_data)
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)] pydantic_warnings = [
warning
for warning in w
if "PydanticSerializationUnexpectedValue" in str(warning.message)
]
if pydantic_warnings: if pydantic_warnings:
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告") print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
for warning in pydantic_warnings: for warning in pydantic_warnings:
@@ -144,6 +154,7 @@ def test_complex_ranking_data():
# 验证序列化结果 # 验证序列化结果
import json import json
parsed = json.loads(json_result) parsed = json.loads(json_result)
assert len(parsed) == 2 assert len(parsed) == 2
assert parsed[0]["user"]["username"] == "player1" assert parsed[0]["user"]["username"] == "player1"
@@ -152,6 +163,7 @@ def test_complex_ranking_data():
except Exception as e: except Exception as e:
print(f"❌ 复杂排行榜数据序列化测试失败: {e}") print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()