咕谷酱
2025-08-22 00:07:19 +08:00
parent bade8658ed
commit 80d4237c5d
22 changed files with 423 additions and 356 deletions

View File

@@ -88,6 +88,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
     # Run the CPU-bound calculation in a thread pool to avoid blocking the event loop
     import asyncio
     loop = asyncio.get_event_loop()
     def _calculate_pp_sync():
@@ -131,11 +132,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
 async def pre_fetch_and_calculate_pp(
-    score: "Score",
-    beatmap_id: int,
-    session: AsyncSession,
-    redis,
-    fetcher
+    score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
 ) -> float:
     """
     Optimized PP calculation: pre-fetch the beatmap file and use the cache
@@ -148,9 +145,7 @@ async def pre_fetch_and_calculate_pp(
     if settings.suspicious_score_check:
         beatmap_banned = (
             await session.exec(
-                select(exists()).where(
-                    col(BannedBeatmaps.beatmap_id) == beatmap_id
-                )
+                select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
             )
         ).first()
         if beatmap_banned:
@@ -184,10 +179,7 @@ async def pre_fetch_and_calculate_pp(
 async def batch_calculate_pp(
-    scores_data: list[tuple["Score", int]],
-    session: AsyncSession,
-    redis,
-    fetcher
+    scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
 ) -> list[float]:
     """
     Batch PP calculation, for recalculation and bulk-processing scenarios
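For reference, a minimal sketch of the thread-pool offload pattern this file relies on; the `_calculate_pp_sync` body here is a hypothetical stand-in for the real PP calculation:

    import asyncio

    async def calculate_pp_offloaded(beatmap_file: str) -> float:
        loop = asyncio.get_event_loop()

        def _calculate_pp_sync() -> float:
            # CPU-bound difficulty/PP math runs in a worker thread,
            # leaving the event loop free to serve other requests.
            return 0.0  # placeholder for the real calculation

        # None selects the default ThreadPoolExecutor
        return await loop.run_in_executor(None, _calculate_pp_sync)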

View File

@@ -142,14 +142,14 @@ class Settings(BaseSettings):
     beatmap_cache_expire_hours: int = 24
     max_concurrent_pp_calculations: int = 10
     enable_pp_calculation_threading: bool = True
     # Ranking cache settings
     enable_ranking_cache: bool = True
     ranking_cache_expire_minutes: int = 10  # ranking cache TTL (minutes)
     ranking_cache_refresh_interval_minutes: int = 10  # ranking cache refresh interval (minutes)
     ranking_cache_max_pages: int = 20  # maximum number of pages to cache
     ranking_cache_top_countries: int = 20  # cache rankings for the top N countries
     # User cache settings
     enable_user_cache_preload: bool = True  # enable user cache preloading
     user_cache_expire_seconds: int = 300  # user info cache TTL (seconds)

View File

@@ -8,7 +8,7 @@ from app.models.score import GameMode
 from .lazer_user import BASE_INCLUDES, User, UserResp
 from pydantic import BaseModel, field_validator, model_validator
-from sqlalchemy import Boolean, JSON, Column, DateTime, Text
+from sqlalchemy import JSON, Boolean, Column, DateTime, Text
 from sqlalchemy.ext.asyncio import AsyncAttrs
 from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
 from sqlmodel.ext.asyncio.session import AsyncSession
@@ -205,7 +205,18 @@ class BeatmapsetResp(BeatmapsetBase):
     favourite_count: int = 0
     recent_favourites: list[UserResp] = Field(default_factory=list)
-    @field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before')
+    @field_validator(
+        "nsfw",
+        "spotlight",
+        "video",
+        "can_be_hyped",
+        "discussion_locked",
+        "storyboard",
+        "discussion_enabled",
+        "is_scoreable",
+        "has_favourited",
+        mode="before",
+    )
     @classmethod
     def validate_bool_fields(cls, v):
         """Convert 0/1 integers to booleans; handles boolean fields coming from the database"""

View File

@@ -2,13 +2,16 @@
 Database field type utilities
 Tools for handling type conversion between the database and Pydantic
 """
-from typing import Any, Union
+from typing import Any
 from pydantic import field_validator
 from sqlalchemy import Boolean
 def bool_field_validator(field_name: str):
     """Create a validator for a specific boolean field that handles 0/1 integers from the database"""
     @field_validator(field_name, mode="before")
     @classmethod
     def validate_bool_field(cls, v: Any) -> bool:
@@ -16,20 +19,21 @@ def bool_field_validator(field_name: str):
         if isinstance(v, int):
             return bool(v)
         return v
     return validate_bool_field
 def create_bool_field(**kwargs):
     """Create a boolean field with the correct SQLAlchemy column definition"""
-    from sqlmodel import Field, Column
+    from sqlmodel import Column, Field
     # If no sa_column is given, fall back to the Boolean type
-    if 'sa_column' not in kwargs:
+    if "sa_column" not in kwargs:
         # Handle the index argument
-        index = kwargs.pop('index', False)
+        index = kwargs.pop("index", False)
         if index:
-            kwargs['sa_column'] = Column(Boolean, index=True)
+            kwargs["sa_column"] = Column(Boolean, index=True)
         else:
-            kwargs['sa_column'] = Column(Boolean)
+            kwargs["sa_column"] = Column(Boolean)
     return Field(**kwargs)
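Assuming the helpers above, usage might look like this (the model is hypothetical; `create_bool_field` is condensed inline so the example runs standalone):

    from sqlalchemy import Boolean, Column
    from sqlmodel import Field, SQLModel

    def create_bool_field(**kwargs):
        # Condensed version of the helper defined in the diff above
        if "sa_column" not in kwargs:
            index = kwargs.pop("index", False)
            kwargs["sa_column"] = Column(Boolean, index=True) if index else Column(Boolean)
        return Field(**kwargs)

    class ExampleTable(SQLModel, table=True):  # hypothetical model
        id: int | None = Field(default=None, primary_key=True)
        is_ranked: bool = create_bool_field(default=True, index=True)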

View File

@@ -136,7 +136,7 @@ class UserBase(UTCBaseModel, SQLModel):
     is_qat: bool = False
     is_bng: bool = False
-    @field_validator('playmode', mode='before')
+    @field_validator("playmode", mode="before")
     @classmethod
     def validate_playmode(cls, v):
         """Convert a string to the GameMode enum"""

View File

@@ -100,7 +100,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
         sa_column=Column(JSON), default_factory=dict
     )
-    @field_validator('maximum_statistics', mode='before')
+    @field_validator("maximum_statistics", mode="before")
     @classmethod
     def validate_maximum_statistics(cls, v):
         """Handle string keys in maximum_statistics by converting them to HitResult enums"""
@@ -151,7 +151,7 @@ class Score(ScoreBase, table=True):
     gamemode: GameMode = Field(index=True)
     pinned_order: int = Field(default=0, exclude=True)
-    @field_validator('gamemode', mode='before')
+    @field_validator("gamemode", mode="before")
     @classmethod
     def validate_gamemode(cls, v):
         """Convert a string to the GameMode enum"""
@@ -209,7 +209,16 @@ class ScoreResp(ScoreBase):
     ranked: bool = False
     current_user_attributes: CurrentUserAttributes | None = None
-    @field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before')
+    @field_validator(
+        "has_replay",
+        "passed",
+        "preserve",
+        "is_perfect_combo",
+        "legacy_perfect",
+        "processed",
+        "ranked",
+        mode="before",
+    )
     @classmethod
     def validate_bool_fields(cls, v):
         """Convert 0/1 integers to booleans; handles boolean fields coming from the database"""
@@ -217,7 +226,7 @@ class ScoreResp(ScoreBase):
             return bool(v)
         return v
-    @field_validator('statistics', 'maximum_statistics', mode='before')
+    @field_validator("statistics", "maximum_statistics", mode="before")
     @classmethod
     def validate_statistics_fields(cls, v):
         """Handle string keys in the statistics fields by converting them to HitResult enums"""

View File

@@ -44,7 +44,7 @@ class UserStatisticsBase(SQLModel):
     replays_watched_by_others: int = Field(default=0)
     is_ranked: bool = Field(default=True)
-    @field_validator('mode', mode='before')
+    @field_validator("mode", mode="before")
     @classmethod
     def validate_mode(cls, v):
         """Convert a string to the GameMode enum"""

View File

@@ -1,7 +1,6 @@
 from __future__ import annotations
 import time
-from typing import Optional
 from app.dependencies.database import get_redis
 from app.log import logger
@@ -11,6 +10,7 @@ from httpx import AsyncClient
 class TokenAuthError(Exception):
     """Raised when token authorization fails"""
     pass
@@ -55,7 +55,7 @@ class BaseFetcher:
         return await self._request_with_retry(url, method, **kwargs)
     async def _request_with_retry(
-        self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs
+        self, url: str, method: str = "GET", max_retries: int | None = None, **kwargs
     ) -> dict:
         """
         Request helper with retry logic
@@ -64,7 +64,7 @@ class BaseFetcher:
             max_retries = self.max_retries
         last_error = None
         for attempt in range(max_retries + 1):
             try:
                 # Check whether the token has expired
@@ -126,7 +126,9 @@ class BaseFetcher:
                     )
                     continue
                 else:
-                    logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
+                    logger.error(
+                        f"Request failed after {max_retries + 1} attempts: {e}"
+                    )
                     break
         # All retries failed
@@ -194,9 +196,13 @@ class BaseFetcher:
                 f"fetcher:refresh_token:{self.client_id}",
                 self.refresh_token,
             )
-            logger.info(f"Successfully refreshed access token for client {self.client_id}")
+            logger.info(
+                f"Successfully refreshed access token for client {self.client_id}"
+            )
         except Exception as e:
-            logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
+            logger.error(
+                f"Failed to refresh access token for client {self.client_id}: {e}"
+            )
             # Clear the invalid tokens and require re-authorization
             self.access_token = ""
             self.refresh_token = ""
@@ -204,7 +210,9 @@ class BaseFetcher:
             redis = get_redis()
             await redis.delete(f"fetcher:access_token:{self.client_id}")
             await redis.delete(f"fetcher:refresh_token:{self.client_id}")
-            logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
+            logger.warning(
+                f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
+            )
             raise
     async def _trigger_reauthorization(self) -> None:
@@ -216,18 +224,18 @@ class BaseFetcher:
             f"Authentication failed after {self._auth_retry_count} attempts. "
             f"Triggering reauthorization for client {self.client_id}"
         )
         # Clear the in-memory tokens
         self.access_token = ""
         self.refresh_token = ""
         self.token_expiry = 0
         self._auth_retry_count = 0  # reset the retry counter
         # Clear the tokens stored in Redis
         redis = get_redis()
         await redis.delete(f"fetcher:access_token:{self.client_id}")
         await redis.delete(f"fetcher:refresh_token:{self.client_id}")
         logger.warning(
             f"All tokens cleared for client {self.client_id}. "
             f"Please re-authorize using: {self.authorize_url}"

View File

@@ -101,6 +101,7 @@ class BeatmapsetFetcher(BaseFetcher):
             return json.loads(cursor_json)
         except Exception:
             return {}
     async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
         logger.opt(colors=True).debug(
             f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
@@ -164,9 +165,7 @@ class BeatmapsetFetcher(BaseFetcher):
         # Cache the result for 15 minutes
         cache_ttl = 15 * 60  # 15 minutes
         await redis_client.set(
-            cache_key,
-            json.dumps(api_response, separators=(",", ":")),
-            ex=cache_ttl
+            cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
         )
         logger.opt(colors=True).debug(
@@ -178,10 +177,12 @@ class BeatmapsetFetcher(BaseFetcher):
         # Smart prefetch: only prefetch on explicit user searches, to avoid excess API requests,
         # and only when there is a search term or specific filters, to avoid over-prefetching front-page browsing
-        if (api_response.get("cursor") and
-                (query.q or query.s != "leaderboard" or cursor)):
+        if api_response.get("cursor") and (
+            query.q or query.s != "leaderboard" or cursor
+        ):
             # Prefetch only the next page in the background, to reduce prefetch volume
             import asyncio
             # Don't create the task immediately; delay the prefetch for a while
             async def delayed_prefetch():
                 await asyncio.sleep(3.0)  # delay by 3 seconds
@@ -200,8 +201,11 @@ class BeatmapsetFetcher(BaseFetcher):
         return resp
     async def prefetch_next_pages(
-        self, query: SearchQueryModel, current_cursor: Cursor,
-        redis_client: redis.Redis, pages: int = 3
+        self,
+        query: SearchQueryModel,
+        current_cursor: Cursor,
+        redis_client: redis.Redis,
+        pages: int = 3,
     ) -> None:
         """Prefetch the next few pages of results"""
         if not current_cursor:
@@ -269,7 +273,7 @@ class BeatmapsetFetcher(BaseFetcher):
                 await redis_client.set(
                     next_cache_key,
                     json.dumps(api_response, separators=(",", ":")),
-                    ex=prefetch_ttl
+                    ex=prefetch_ttl,
                 )
                 logger.opt(colors=True).debug(
@@ -317,7 +321,6 @@ class BeatmapsetFetcher(BaseFetcher):
             params=params,
         )
         if api_response.get("cursor"):
             cursor_dict = api_response["cursor"]
             api_response["cursor_string"] = self._encode_cursor(cursor_dict)
@@ -327,7 +330,7 @@ class BeatmapsetFetcher(BaseFetcher):
         await redis_client.set(
             cache_key,
             json.dumps(api_response, separators=(",", ":")),
-            ex=cache_ttl
+            ex=cache_ttl,
         )
         logger.opt(colors=True).info(
@@ -335,7 +338,6 @@ class BeatmapsetFetcher(BaseFetcher):
             f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
         )
         if api_response.get("cursor"):
             await self.prefetch_next_pages(
                 query, api_response["cursor"], redis_client, pages=2
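A compact sketch of the delayed background-prefetch idiom used here (`fetch_page` is a hypothetical stand-in for the real page fetcher):

    import asyncio

    async def fetch_page(cursor: dict) -> None:
        """Hypothetical stand-in: fetch and cache the next page."""
        ...

    def schedule_delayed_prefetch(cursor: dict, delay: float = 3.0) -> None:
        async def delayed_prefetch() -> None:
            # Waiting briefly avoids a burst of upstream API calls
            # while the user is still paging through results.
            await asyncio.sleep(delay)
            await fetch_page(cursor)

        # Fire-and-forget: the response is not blocked on the prefetch.
        # Must be called while an event loop is running.
        asyncio.create_task(delayed_prefetch())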

View File

@@ -5,6 +5,7 @@ Rate limiter for osu! API requests to avoid abuse detection.
 - Burst: at most 200 additional requests within a short window
 - Recommended: no more than 60 requests per minute to avoid abuse detection
 """
 from __future__ import annotations
 import asyncio
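A minimal token-bucket sketch matching the limits in this docstring; the limiter's actual implementation is not shown in the diff, and this simplified version sleeps while holding the lock:

    import asyncio
    import time

    class TokenBucket:
        def __init__(self, rate_per_minute: float = 60, burst: int = 200) -> None:
            self.rate = rate_per_minute / 60.0  # tokens added per second
            self.capacity = burst
            self.tokens = float(burst)
            self.updated = time.monotonic()
            self._lock = asyncio.Lock()

        async def acquire(self) -> None:
            async with self._lock:
                now = time.monotonic()
                # Refill proportionally to elapsed time, capped at the burst size
                self.tokens = min(
                    self.capacity, self.tokens + (now - self.updated) * self.rate
                )
                self.updated = now
                if self.tokens < 1:
                    await asyncio.sleep((1 - self.tokens) / self.rate)
                    self.tokens = 1
                self.tokens -= 1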

View File

@@ -4,7 +4,6 @@ import asyncio
 from datetime import datetime
 from typing import Literal
-from app.config import settings
 from app.database.lazer_user import User
 from app.database.statistics import UserStatistics, UserStatisticsResp
 from app.dependencies.database import Database, get_redis
@@ -56,7 +55,7 @@ class V1User(AllStrModel):
         # Make sure user_id is not None
         if db_user.id is None:
             raise ValueError("User ID cannot be None")
         ruleset = ruleset or db_user.playmode
         current_statistics: UserStatistics | None = None
         for i in await db_user.awaitable_attrs.statistics:
@@ -118,26 +117,28 @@ async def get_user(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
     # Determine the query type and user ID
     is_id_query = type == "id" or user.isdigit()
     # Parse the ruleset
     ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None
     # For ID queries, try the cache first
     cached_v1_user = None
     user_id_for_cache = None
     if is_id_query:
         try:
             user_id_for_cache = int(user)
-            cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
+            cached_v1_user = await cache_service.get_v1_user_from_cache(
+                user_id_for_cache, ruleset
+            )
             if cached_v1_user:
                 return [V1User(**cached_v1_user)]
         except (ValueError, TypeError):
             pass  # not a valid user ID; continue with the database query
     # Query the user from the database
     db_user = (
         await session.exec(
@@ -146,23 +147,23 @@
         )
     ).first()
     if not db_user:
         return []
     try:
         # Build the user payload
         v1_user = await V1User.from_db(session, db_user, ruleset)
         # Cache the result asynchronously (when we have a user ID)
         if db_user.id is not None:
             user_data = v1_user.model_dump()
             asyncio.create_task(
                 cache_service.cache_v1_user(user_data, db_user.id, ruleset)
             )
         return [v1_user]
     except KeyError:
         raise HTTPException(400, "Invalid request")
     except ValueError as e:
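The endpoint follows a cache-aside pattern; stripped to its core (the dict cache and fetch coroutine stand in for the Redis-backed service and the database query):

    import asyncio

    async def get_user_cached(user_id: int, cache: dict, fetch_from_db) -> dict:
        # 1. Try the cache first
        if user_id in cache:
            return cache[user_id]
        # 2. Cache miss: fall back to the database
        user = await fetch_from_db(user_id)

        # 3. Write back asynchronously so the response is not delayed
        async def _store() -> None:
            cache[user_id] = user

        asyncio.create_task(_store())
        return user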

View File

@@ -2,15 +2,15 @@
 Cache management and monitoring endpoints
 Provides cache statistics, invalidation, and warmup
 """
 from __future__ import annotations
 from app.dependencies.database import get_redis
-from app.dependencies.user import get_current_user
 from app.service.user_cache_service import get_user_cache_service
 from .router import router
-from fastapi import Depends, HTTPException, Security
+from fastapi import Depends, HTTPException
 from pydantic import BaseModel
 from redis.asyncio import Redis
@@ -34,7 +34,7 @@ async def get_cache_stats(
     try:
         cache_service = get_user_cache_service(redis)
         user_cache_stats = await cache_service.get_cache_stats()
         # Fetch basic Redis info
         redis_info = await redis.info()
         redis_stats = {
@@ -47,20 +47,17 @@ async def get_cache_stats(
             "evicted_keys": redis_info.get("evicted_keys", 0),
             "expired_keys": redis_info.get("expired_keys", 0),
         }
         # Compute the cache hit rate
         hits = redis_stats["keyspace_hits"]
         misses = redis_stats["keyspace_misses"]
         hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
         redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
-        return CacheStatsResponse(
-            user_cache=user_cache_stats,
-            redis_info=redis_stats
-        )
+        return CacheStatsResponse(user_cache=user_cache_stats, redis_info=redis_stats)
     except Exception as e:
-        raise HTTPException(500, f"Failed to get cache stats: {str(e)}")
+        raise HTTPException(500, f"Failed to get cache stats: {e!s}")
 @router.post(
@@ -80,7 +77,7 @@ async def invalidate_user_cache(
         await cache_service.invalidate_v1_user_cache(user_id)
         return {"message": f"Cache invalidated for user {user_id}"}
     except Exception as e:
-        raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
+        raise HTTPException(500, f"Failed to invalidate cache: {e!s}")
 @router.post(
@@ -98,15 +95,15 @@ async def clear_all_user_cache(
         user_keys = await redis.keys("user:*")
         v1_user_keys = await redis.keys("v1_user:*")
         all_keys = user_keys + v1_user_keys
         if all_keys:
             await redis.delete(*all_keys)
             return {"message": f"Cleared {len(all_keys)} cache entries"}
         else:
             return {"message": "No cache entries found"}
     except Exception as e:
-        raise HTTPException(500, f"Failed to clear cache: {str(e)}")
+        raise HTTPException(500, f"Failed to clear cache: {e!s}")
 class CacheWarmupRequest(BaseModel):
@@ -127,18 +124,22 @@ async def warmup_cache(
 ):
     try:
         cache_service = get_user_cache_service(redis)
         if request.user_ids:
             # Warm up the specified users
             from app.dependencies.database import with_db
             async with with_db() as session:
                 await cache_service.preload_user_cache(session, request.user_ids)
             return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
         else:
             # Warm up active users
-            from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
+            from app.scheduler.user_cache_scheduler import (
+                schedule_user_cache_preload_task,
+            )
             await schedule_user_cache_preload_task()
             return {"message": f"Warmed up cache for top {request.limit} active users"}
     except Exception as e:
-        raise HTTPException(500, f"Failed to warmup cache: {str(e)}")
+        raise HTTPException(500, f"Failed to warmup cache: {e!s}")

View File

@@ -45,24 +45,24 @@ async def get_country_ranking(
     # Get the Redis connection and cache service
     redis = get_redis()
     cache_service = get_ranking_cache_service(redis)
     # Try to read the data from the cache
     cached_data = await cache_service.get_cached_country_ranking(ruleset, page)
     if cached_data:
         # Serve the response from the cache
         return CountryResponse(
             ranking=[CountryStatistics.model_validate(item) for item in cached_data]
         )
     # Cache miss: query the database
     response = CountryResponse(ranking=[])
     countries = (await session.exec(select(User.country_code).distinct())).all()
     for country in countries:
         if not country:  # skip empty country codes
             continue
         statistics = (
             await session.exec(
                 select(UserStatistics).where(
@@ -73,10 +73,10 @@ async def get_country_ranking(
                 )
             )
         ).all()
         if not statistics:  # skip countries with no data
             continue
         pp = 0
         country_stats = CountryStatistics(
             code=country,
@@ -92,27 +92,28 @@ async def get_country_ranking(
             pp += stat.pp
         country_stats.performance = round(pp)
         response.ranking.append(country_stats)
     response.ranking.sort(key=lambda x: x.performance, reverse=True)
     # Pagination
     page_size = 50
     start_idx = (page - 1) * page_size
     end_idx = start_idx + page_size
     # Slice out the current page
     current_page_data = response.ranking[start_idx:end_idx]
     # Cache the data asynchronously (without waiting for completion)
     cache_data = [item.model_dump() for item in current_page_data]
     cache_task = cache_service.cache_country_ranking(
         ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
     )
     # Run the caching as a background task
     import asyncio
     asyncio.create_task(cache_task)
     # Return the current page
     response.ranking = current_page_data
     return response
@@ -142,20 +143,16 @@ async def get_user_ranking(
     # Get the Redis connection and cache service
     redis = get_redis()
     cache_service = get_ranking_cache_service(redis)
     # Try to read the data from the cache
-    cached_data = await cache_service.get_cached_ranking(
-        ruleset, type, country, page
-    )
+    cached_data = await cache_service.get_cached_ranking(ruleset, type, country, page)
     if cached_data:
         # Serve the response from the cache
         return TopUsersResponse(
-            ranking=[
-                UserStatisticsResp.model_validate(item) for item in cached_data
-            ]
+            ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
         )
     # Cache miss: query the database
     wheres = [
         col(UserStatistics.mode) == ruleset,
@@ -170,7 +167,7 @@ async def get_user_ranking(
         order_by = col(UserStatistics.ranked_score).desc()
     if country:
         wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
     statistics_list = await session.exec(
         select(UserStatistics)
         .where(*wheres)
@@ -178,7 +175,7 @@ async def get_user_ranking(
         .limit(50)
         .offset(50 * (page - 1))
     )
     # Convert to the response format
     ranking_data = []
     for statistics in statistics_list:
@@ -186,18 +183,24 @@ async def get_user_ranking(
             statistics, session, None, include
         )
         ranking_data.append(user_stats_resp)
     # Cache the data asynchronously (without waiting for completion),
     # using the TTL setting from the configuration file
     cache_data = [item.model_dump() for item in ranking_data]
     cache_task = cache_service.cache_ranking(
-        ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60
+        ruleset,
+        type,
+        cache_data,
+        country,
+        page,
+        ttl=settings.ranking_cache_expire_minutes * 60,
     )
     # Run the caching as a background task
     import asyncio
     asyncio.create_task(cache_task)
     resp = TopUsersResponse(ranking=ranking_data)
     return resp
@@ -328,4 +331,4 @@ async def get_ranking_cache_stats(
     cache_service = get_ranking_cache_service(redis)
     stats = await cache_service.get_cache_stats()
     return stats """
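The country ranking is aggregated in memory, sorted, and then sliced per page; schematically:

    page = 2
    page_size = 50

    # ranking: the full, already-sorted list of country aggregates
    ranking = [{"code": f"C{i}", "performance": 1000 - i} for i in range(120)]

    start_idx = (page - 1) * page_size
    end_idx = start_idx + page_size
    current_page = ranking[start_idx:end_idx]  # entries 50..99 for page 2
    print(len(current_page), current_page[0]["code"])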

View File

@@ -183,7 +183,7 @@ async def submit_score(
     }
     db.add(rank_event)
     await db.commit()
     # Refresh the user cache after the score is submitted
     try:
         user_cache_service = get_user_cache_service(redis)
@@ -193,7 +193,7 @@ async def submit_score(
         )
     except Exception as e:
         logger.error(f"Failed to refresh user cache after score submit: {e}")
     background_task.add_task(process_user_achievement, resp.id)
     return resp

View File

@@ -57,25 +57,27 @@ async def get_users(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
     if user_ids:
         # Try the cache first
         cached_users = []
         uncached_user_ids = []
         for user_id in user_ids[:50]:  # cap at 50 users
             cached_user = await cache_service.get_user_from_cache(user_id)
             if cached_user:
                 cached_users.append(cached_user)
             else:
                 uncached_user_ids.append(user_id)
         # Query the users that were not cached
         if uncached_user_ids:
             searched_users = (
-                await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
+                await session.exec(
+                    select(User).where(col(User.id).in_(uncached_user_ids))
+                )
             ).all()
             # Add the queried users to the cache and the response
             for searched_user in searched_users:
                 if searched_user.id != BANCHOBOT_ID:
@@ -87,7 +89,7 @@ async def get_users(
                     cached_users.append(user_resp)
                     # Cache asynchronously without blocking the response
                     asyncio.create_task(cache_service.cache_user(user_resp))
         return BatchUserResponse(users=cached_users)
     else:
         searched_users = (await session.exec(select(User).limit(50))).all()
@@ -102,7 +104,7 @@ async def get_users(
                 users.append(user_resp)
                 # Cache asynchronously
                 asyncio.create_task(cache_service.cache_user(user_resp))
         return BatchUserResponse(users=users)
@@ -121,14 +123,14 @@ async def get_user_info_ruleset(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
     # For numeric IDs, try the cache first
     if user_id.isdigit():
         user_id_int = int(user_id)
         cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset)
         if cached_user:
             return cached_user
     searched_user = (
         await session.exec(
             select(User).where(
@@ -140,17 +142,17 @@ async def get_user_info_ruleset(
     ).first()
     if not searched_user or searched_user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
     user_resp = await UserResp.from_db(
         searched_user,
         session,
         include=SEARCH_INCLUDED,
         ruleset=ruleset,
     )
     # Cache the result asynchronously
     asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
     return user_resp
@@ -169,14 +171,14 @@ async def get_user_info(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
     # For numeric IDs, try the cache first
     if user_id.isdigit():
         user_id_int = int(user_id)
         cached_user = await cache_service.get_user_from_cache(user_id_int)
         if cached_user:
             return cached_user
     searched_user = (
         await session.exec(
             select(User).where(
@@ -188,16 +190,16 @@ async def get_user_info(
     ).first()
     if not searched_user or searched_user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
     user_resp = await UserResp.from_db(
         searched_user,
         session,
         include=SEARCH_INCLUDED,
     )
     # Cache the result asynchronously
     asyncio.create_task(cache_service.cache_user(user_resp))
     return user_resp
@@ -218,7 +220,7 @@ async def get_user_beatmapsets(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
     # Try the cache first
     cached_result = await cache_service.get_user_beatmapsets_from_cache(
         user_id, type.value, limit, offset
@@ -229,7 +231,7 @@ async def get_user_beatmapsets(
             return [BeatmapPlaycountsResp(**item) for item in cached_result]
         else:
             return [BeatmapsetResp(**item) for item in cached_result]
     user = await session.get(User, user_id)
     if not user or user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
@@ -275,10 +277,14 @@ async def get_user_beatmapsets(
     # Cache the result asynchronously
     async def cache_beatmapsets():
         try:
-            await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
+            await cache_service.cache_user_beatmapsets(
+                user_id, type.value, resp, limit, offset
+            )
         except Exception as e:
-            logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
+            logger.error(
+                f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
+            )
     asyncio.create_task(cache_beatmapsets())
     return resp
@@ -311,7 +317,7 @@ async def get_user_scores(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
     # Try the cache first; the "recent" type uses a shorter TTL
     cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
     cached_scores = await cache_service.get_user_scores_from_cache(
@@ -319,7 +325,7 @@ async def get_user_scores(
     )
     if cached_scores is not None:
         return cached_scores
     db_user = await session.get(User, user_id)
     if not db_user or db_user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
@@ -355,7 +361,7 @@ async def get_user_scores(
     ).all()
     if not scores:
         return []
     score_responses = [
         await ScoreResp.from_db(
             session,
@@ -363,14 +369,14 @@ async def get_user_scores(
         )
         for score in scores
     ]
     # Cache the result asynchronously
     asyncio.create_task(
         cache_service.cache_user_scores(
             user_id, type, score_responses, mode, limit, offset, cache_expire
         )
     )
     return score_responses
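A small sketch of the fire-and-forget caching wrapper used throughout this file (`cache_call` stands in for the real cache-write coroutine):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    async def cache_call() -> None:
        """Hypothetical stand-in for an awaitable cache write."""
        ...

    def cache_in_background() -> None:
        async def _wrapped() -> None:
            try:
                await cache_call()
            except Exception as e:
                # Swallow and log: a failed cache write must not fail the request
                logger.error(f"Error caching response: {e}")

        asyncio.create_task(_wrapped())  # requires a running event loop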

View File

@@ -1,4 +1,5 @@
 """Cache scheduler module"""
 from __future__ import annotations
 from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 from app.config import settings
-from app.dependencies.database import get_db, get_redis
+from app.dependencies.database import get_redis
 from app.dependencies.fetcher import get_fetcher
 from app.log import logger
 from app.scheduler.user_cache_scheduler import (
@@ -44,10 +44,10 @@ class CacheScheduler:
         """Run the scheduler's main loop"""
         # Warm the cache once immediately at startup
         await self._warmup_cache()
         # Refresh the ranking cache once at startup
         await self._refresh_ranking_cache()
         # Warm the user cache once at startup
         await self._warmup_user_cache()
@@ -55,14 +55,16 @@ class CacheScheduler:
         ranking_cache_counter = 0
         user_cache_counter = 0
         user_cleanup_counter = 0
         # Read the interval settings from the configuration file
         check_interval = 5 * 60  # 5-minute check interval
         beatmap_cache_interval = 30 * 60  # 30-minute beatmap cache interval
-        ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60  # read from the configuration
+        ranking_cache_interval = (
+            settings.ranking_cache_refresh_interval_minutes * 60
+        )  # read from the configuration
         user_cache_interval = 15 * 60  # 15-minute user cache preload interval
         user_cleanup_interval = 60 * 60  # 60-minute user cache cleanup interval
         beatmap_cache_cycles = beatmap_cache_interval // check_interval
         ranking_cache_cycles = ranking_cache_interval // check_interval
         user_cache_cycles = user_cache_interval // check_interval
@@ -90,12 +92,12 @@ class CacheScheduler:
             if ranking_cache_counter >= ranking_cache_cycles:
                 await self._refresh_ranking_cache()
                 ranking_cache_counter = 0
             # User cache preload
             if user_cache_counter >= user_cache_cycles:
                 await self._preload_user_cache()
                 user_cache_counter = 0
             # User cache cleanup
             if user_cleanup_counter >= user_cleanup_cycles:
                 await self._cleanup_user_cache()
@@ -129,15 +131,14 @@ class CacheScheduler:
             logger.info("Starting ranking cache refresh...")
             redis = get_redis()
             # Import the ranking cache service
+            # Use an independent database session
+            from app.dependencies.database import with_db
             from app.service.ranking_cache_service import (
-                get_ranking_cache_service,
                 schedule_ranking_refresh_task,
             )
-            # Use an independent database session
-            from app.dependencies.database import with_db
             async with with_db() as session:
                 await schedule_ranking_refresh_task(session, redis)
@@ -171,6 +172,7 @@ class CacheScheduler:
 # Beatmap cache scheduler, kept for backward compatibility
 class BeatmapsetCacheScheduler(CacheScheduler):
     """Beatmapset cache scheduler - kept for backward compatibility"""
     pass
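The counter-based scheduling in the main loop reduces to this shape (intervals shortened for illustration):

    import asyncio

    async def run_scheduler() -> None:
        check_interval = 5   # seconds between wakeups (5 * 60 in the real code)
        ranking_cycles = 2   # refresh once every N wakeups
        ranking_counter = 0
        while True:
            await asyncio.sleep(check_interval)
            ranking_counter += 1
            if ranking_counter >= ranking_cycles:
                print("refresh ranking cache")  # placeholder for the real refresh
                ranking_counter = 0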

View File

@@ -1,15 +1,15 @@
 """
 User cache warmup task scheduler
 """
 from __future__ import annotations
 import asyncio
 from datetime import UTC, datetime, timedelta
 from app.config import settings
-from app.database import User
 from app.database.score import Score
-from app.dependencies.database import get_db, get_redis
+from app.dependencies.database import get_redis
 from app.log import logger
 from app.service.user_cache_service import get_user_cache_service
@@ -25,16 +25,17 @@ async def schedule_user_cache_preload_task():
     try:
         logger.info("Starting user cache preload task...")
         redis = get_redis()
         cache_service = get_user_cache_service(redis)
         # Use an independent database session
         from app.dependencies.database import with_db
         async with with_db() as session:
             # Get users active in the last 24 hours (users who submitted scores)
             recent_time = datetime.now(UTC) - timedelta(hours=24)
             active_user_ids = (
                 await session.exec(
                     select(Score.user_id, func.count().label("score_count"))
@@ -44,7 +45,7 @@ async def schedule_user_cache_preload_task():
                     .limit(settings.user_cache_max_preload_users)  # limit from the configuration
                 )
             ).all()
             if active_user_ids:
                 user_ids = [row[0] for row in active_user_ids]
                 await cache_service.preload_user_cache(session, user_ids)
@@ -62,17 +63,18 @@ async def schedule_user_cache_warmup_task():
     """Scheduled user cache warmup task - preload the top 100 ranked users"""
     try:
         logger.info("Starting user cache warmup task...")
         redis = get_redis()
         cache_service = get_user_cache_service(redis)
         # Use an independent database session
         from app.dependencies.database import with_db
         async with with_db() as session:
             # Get the top 100 users on the global rankings
             from app.database.statistics import UserStatistics
             from app.models.score import GameMode
             for mode in GameMode:
                 try:
                     top_users = (
@@ -83,15 +85,15 @@ async def schedule_user_cache_warmup_task():
                             .limit(100)
                         )
                     ).all()
                     if top_users:
                         user_ids = list(top_users)
                         await cache_service.preload_user_cache(session, user_ids)
                         logger.info(f"Warmed cache for top 100 users in {mode}")
                     # Sleep briefly to avoid overload
                     await asyncio.sleep(1)
                 except Exception as e:
                     logger.error(f"Failed to warm cache for {mode}: {e}")
                     continue
@@ -106,13 +108,13 @@ async def schedule_user_cache_cleanup_task():
     """Scheduled user cache cleanup task"""
     try:
         logger.info("Starting user cache cleanup task...")
         redis = get_redis()
         # Clean up expired user cache entries (Redis handles TTL automatically; this mainly logs statistics)
         cache_service = get_user_cache_service(redis)
         stats = await cache_service.get_cache_stats()
         logger.info(f"User cache stats: {stats}")
         logger.info("User cache cleanup task completed successfully")

View File

@@ -2,6 +2,7 @@
 Beatmap cache prefetch service
 Caches popular beatmaps ahead of time to reduce fetch latency during score calculation
 """
 from __future__ import annotations
 import asyncio
@@ -155,9 +156,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
 async def schedule_preload_task(
-    session: AsyncSession,
-    redis: Redis,
-    fetcher: "Fetcher"
+    session: AsyncSession, redis: Redis, fetcher: "Fetcher"
 ):
     """
     Scheduled preload task

View File

@@ -2,11 +2,12 @@
 User ranking cache service
 Caches user ranking data to reduce database load
 """
 from __future__ import annotations
 import asyncio
+from datetime import UTC, datetime
 import json
-from datetime import UTC, datetime, timedelta
 from typing import TYPE_CHECKING, Literal
 from app.config import settings
@@ -24,7 +25,7 @@ if TYPE_CHECKING:
 class DateTimeEncoder(json.JSONEncoder):
     """Custom JSON encoder with datetime serialization support"""
     def default(self, obj):
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
 def safe_json_dumps(data) -> str:
     """Safe JSON serialization with datetime support"""
-    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
+    return json.dumps(
+        data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
+    )
 class RankingCacheService:
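Self-contained, the encoder and helper behave like this:

    import json
    from datetime import UTC, datetime

    class DateTimeEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, datetime):
                return obj.isoformat()  # serialize datetimes as ISO-8601 strings
            return super().default(obj)

    payload = {"last_updated": datetime.now(UTC), "total_users": 42}
    print(json.dumps(payload, cls=DateTimeEncoder, separators=(",", ":")))
    # e.g. {"last_updated":"2025-08-21T16:07:19+00:00","total_users":42}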
@@ -84,7 +87,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_cache_key(ruleset, type, country, page)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -107,11 +110,7 @@ class RankingCacheService:
             # Use the TTL from the configuration file
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(ranking_data),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
             logger.debug(f"Cached ranking data for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching ranking: {e}")
@@ -126,7 +125,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_stats_cache_key(ruleset, type, country)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -148,11 +147,7 @@ class RankingCacheService:
             # Use the TTL from the configuration file (statistics are cached longer)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60 * 6  # 6x the normal TTL
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(stats),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
             logger.debug(f"Cached stats for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching stats: {e}")
@@ -166,7 +161,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_country_cache_key(ruleset, page)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -186,11 +181,7 @@ class RankingCacheService:
             cache_key = self._get_country_cache_key(ruleset, page)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(ranking_data),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
             logger.debug(f"Cached country ranking data for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching country ranking: {e}")
@@ -200,7 +191,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_country_stats_cache_key(ruleset)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -219,11 +210,7 @@ class RankingCacheService:
             cache_key = self._get_country_stats_cache_key(ruleset)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60 * 6  # 6x the normal TTL
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(stats),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
             logger.debug(f"Cached country stats for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching country stats: {e}")
@@ -238,7 +225,9 @@ class RankingCacheService:
     ) -> None:
         """Refresh the ranking cache"""
         if self._refreshing:
-            logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
+            logger.info(
+                f"Ranking cache refresh already in progress for {ruleset}:{type}"
+            )
             return
         # Use the settings from the configuration file
@@ -248,7 +237,7 @@ class RankingCacheService:
         self._refreshing = True
         try:
             logger.info(f"Starting ranking cache refresh for {ruleset}:{type}")
             # Build the query conditions
             wheres = [
                 col(UserStatistics.mode) == ruleset,
@@ -256,20 +245,22 @@ class RankingCacheService:
                 col(UserStatistics.is_ranked).is_(True),
             ]
             include = ["user"]
             if type == "performance":
                 order_by = col(UserStatistics.pp).desc()
                 include.append("rank_change_since_30_days")
             else:
                 order_by = col(UserStatistics.ranked_score).desc()
             if country:
-                wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
+                wheres.append(
+                    col(UserStatistics.user).has(country_code=country.upper())
+                )
             # Get the total user count for the statistics
             total_users_query = select(UserStatistics).where(*wheres)
             total_users = len((await session.exec(total_users_query)).all())
             # Compute the statistics
             stats = {
                 "total_users": total_users,
@@ -278,7 +269,7 @@ class RankingCacheService:
                 "ruleset": ruleset,
                 "country": country,
             }
             # Cache the statistics
             await self.cache_stats(ruleset, type, stats, country)
@@ -292,11 +283,11 @@ class RankingCacheService:
                         .limit(50)
                         .offset(50 * (page - 1))
                     )
                     statistics_data = statistics_list.all()
                     if not statistics_data:
                         break  # no more data
                     # Convert to the response format and make sure it serializes correctly
                     ranking_data = []
                     for statistics in statistics_data:
@@ -306,21 +297,19 @@ class RankingCacheService:
                     # Convert UserStatisticsResp to a dict, resolving any serialization issues
                     user_dict = json.loads(user_stats_resp.model_dump_json())
                     ranking_data.append(user_dict)
                     # Cache this page of data
-                    await self.cache_ranking(
-                        ruleset, type, ranking_data, country, page
-                    )
+                    await self.cache_ranking(ruleset, type, ranking_data, country, page)
                     # Add a small delay to avoid overloading the database
                     if page < max_pages:
                         await asyncio.sleep(0.1)
                 except Exception as e:
                     logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}")
             logger.info(f"Completed ranking cache refresh for {ruleset}:{type}")
         except Exception as e:
             logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}")
         finally:
@@ -334,7 +323,9 @@ class RankingCacheService:
     ) -> None:
         """Refresh the country ranking cache"""
         if self._refreshing:
-            logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
+            logger.info(
+                f"Country ranking cache refresh already in progress for {ruleset}"
+            )
             return
         if max_pages is None:
@@ -343,17 +334,18 @@ class RankingCacheService:
self._refreshing = True self._refreshing = True
try: try:
logger.info(f"Starting country ranking cache refresh for {ruleset}") logger.info(f"Starting country ranking cache refresh for {ruleset}")
# 获取所有国家 # 获取所有国家
from app.database import User from app.database import User
countries = (await session.exec(select(User.country_code).distinct())).all() countries = (await session.exec(select(User.country_code).distinct())).all()
# 计算每个国家的统计数据 # 计算每个国家的统计数据
country_stats_list = [] country_stats_list = []
for country in countries: for country in countries:
if not country: # 跳过空的国家代码 if not country: # 跳过空的国家代码
continue continue
statistics = ( statistics = (
await session.exec( await session.exec(
select(UserStatistics).where( select(UserStatistics).where(
@@ -364,10 +356,10 @@ class RankingCacheService:
) )
) )
).all() ).all()
if not statistics: # 跳过没有数据的国家 if not statistics: # 跳过没有数据的国家
continue continue
pp = 0 pp = 0
country_stats = { country_stats = {
"code": country, "code": country,
@@ -376,48 +368,48 @@ class RankingCacheService:
"ranked_score": 0, "ranked_score": 0,
"performance": 0, "performance": 0,
} }
for stat in statistics: for stat in statistics:
country_stats["active_users"] += 1 country_stats["active_users"] += 1
country_stats["play_count"] += stat.play_count country_stats["play_count"] += stat.play_count
country_stats["ranked_score"] += stat.ranked_score country_stats["ranked_score"] += stat.ranked_score
pp += stat.pp pp += stat.pp
country_stats["performance"] = round(pp) country_stats["performance"] = round(pp)
country_stats_list.append(country_stats) country_stats_list.append(country_stats)
# 按表现分排序 # 按表现分排序
country_stats_list.sort(key=lambda x: x["performance"], reverse=True) country_stats_list.sort(key=lambda x: x["performance"], reverse=True)
# 计算统计信息 # 计算统计信息
stats = { stats = {
"total_countries": len(country_stats_list), "total_countries": len(country_stats_list),
"last_updated": datetime.now(UTC).isoformat(), "last_updated": datetime.now(UTC).isoformat(),
"ruleset": ruleset, "ruleset": ruleset,
} }
# 缓存统计信息 # 缓存统计信息
await self.cache_country_stats(ruleset, stats) await self.cache_country_stats(ruleset, stats)
# 分页缓存数据每页50个国家 # 分页缓存数据每页50个国家
page_size = 50 page_size = 50
for page in range(1, max_pages + 1): for page in range(1, max_pages + 1):
start_idx = (page - 1) * page_size start_idx = (page - 1) * page_size
end_idx = start_idx + page_size end_idx = start_idx + page_size
page_data = country_stats_list[start_idx:end_idx] page_data = country_stats_list[start_idx:end_idx]
if not page_data: if not page_data:
break # 没有更多数据 break # 没有更多数据
# 缓存这一页的数据 # 缓存这一页的数据
await self.cache_country_ranking(ruleset, page_data, page) await self.cache_country_ranking(ruleset, page_data, page)
# 添加延迟避免Redis过载 # 添加延迟避免Redis过载
if page < max_pages and page_data: if page < max_pages and page_data:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
logger.info(f"Completed country ranking cache refresh for {ruleset}") logger.info(f"Completed country ranking cache refresh for {ruleset}")
except Exception as e: except Exception as e:
logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}") logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}")
finally: finally:
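The per-country rollup above is a plain fold over user statistics. With made-up numbers, the same arithmetic looks like this:

# Made-up numbers illustrating the per-country rollup above.
stats = [
    {"play_count": 1200, "ranked_score": 5_000_000, "pp": 4321.7},
    {"play_count": 800, "ranked_score": 2_500_000, "pp": 2100.2},
]

country = {
    "code": "US",
    "active_users": 0,
    "play_count": 0,
    "ranked_score": 0,
    "performance": 0,
}
pp = 0.0
for s in stats:
    country["active_users"] += 1
    country["play_count"] += s["play_count"]
    country["ranked_score"] += s["ranked_score"]
    pp += s["pp"]

# Performance is the rounded sum of raw pp: 4321.7 + 2100.2 = 6421.9 -> 6422
country["performance"] = round(pp)
assert country["performance"] == 6422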
@@ -427,11 +419,12 @@ class RankingCacheService:
"""刷新所有排行榜缓存""" """刷新所有排行榜缓存"""
game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA] game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"] ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
# 获取需要缓存的国家列表活跃用户数量前20的国家 # 获取需要缓存的国家列表活跃用户数量前20的国家
from app.database import User from app.database import User
from sqlmodel import func from sqlmodel import func
countries_query = ( countries_query = (
await session.exec( await session.exec(
select(User.country_code, func.count().label("user_count")) select(User.country_code, func.count().label("user_count"))
@@ -441,38 +434,40 @@ class RankingCacheService:
                .limit(settings.ranking_cache_top_countries)
            )
        ).all()
        top_countries = [country for country, _ in countries_query]

        refresh_tasks = []

        # Global rankings
        for mode in game_modes:
            for ranking_type in ranking_types:
                task = self.refresh_ranking_cache(session, mode, ranking_type)
                refresh_tasks.append(task)

        # Country rankings (top 20 countries only)
        for country in top_countries:
            for mode in game_modes:
                for ranking_type in ranking_types:
-                    task = self.refresh_ranking_cache(session, mode, ranking_type, country)
+                    task = self.refresh_ranking_cache(
+                        session, mode, ranking_type, country
+                    )
                    refresh_tasks.append(task)

        # Regional rankings
        for mode in game_modes:
            task = self.refresh_country_ranking_cache(session, mode)
            refresh_tasks.append(task)

        # Run the refresh tasks concurrently, with a cap on concurrency
        semaphore = asyncio.Semaphore(5)  # at most 5 tasks at a time

        async def bounded_refresh(task):
            async with semaphore:
                await task

        bounded_tasks = [bounded_refresh(task) for task in refresh_tasks]

        try:
            await asyncio.gather(*bounded_tasks, return_exceptions=True)
            logger.info("All ranking cache refresh completed")
@@ -489,7 +484,7 @@ class RankingCacheService:
"""使缓存失效""" """使缓存失效"""
try: try:
deleted_keys = 0 deleted_keys = 0
if ruleset and type: if ruleset and type:
# 删除特定的用户排行榜缓存 # 删除特定的用户排行榜缓存
country_part = f":{country.upper()}" if country else "" country_part = f":{country.upper()}" if country else ""
@@ -498,12 +493,14 @@ class RankingCacheService:
                if keys:
                    await self.redis.delete(*keys)
                    deleted_keys += len(keys)
-                    logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
+                    logger.info(
+                        f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
+                    )
            elif ruleset:
                # Delete every cache for a specific game mode
                patterns = [
                    f"ranking:{ruleset}:*",
-                    f"country_ranking:{ruleset}:*" if include_country_ranking else None
+                    f"country_ranking:{ruleset}:*" if include_country_ranking else None,
                ]
                for pattern in patterns:
                    if pattern:
@@ -516,15 +513,15 @@ class RankingCacheService:
patterns = ["ranking:*"] patterns = ["ranking:*"]
if include_country_ranking: if include_country_ranking:
patterns.append("country_ranking:*") patterns.append("country_ranking:*")
for pattern in patterns: for pattern in patterns:
keys = await self.redis.keys(pattern) keys = await self.redis.keys(pattern)
if keys: if keys:
await self.redis.delete(*keys) await self.redis.delete(*keys)
deleted_keys += len(keys) deleted_keys += len(keys)
logger.info(f"Invalidated all {deleted_keys} ranking cache keys") logger.info(f"Invalidated all {deleted_keys} ranking cache keys")
except Exception as e: except Exception as e:
logger.error(f"Error invalidating cache: {e}") logger.error(f"Error invalidating cache: {e}")
@@ -535,7 +532,7 @@ class RankingCacheService:
pattern = f"country_ranking:{ruleset}:*" pattern = f"country_ranking:{ruleset}:*"
else: else:
pattern = "country_ranking:*" pattern = "country_ranking:*"
keys = await self.redis.keys(pattern) keys = await self.redis.keys(pattern)
if keys: if keys:
await self.redis.delete(*keys) await self.redis.delete(*keys)
@@ -550,10 +547,10 @@ class RankingCacheService:
            ranking_keys = await self.redis.keys("ranking:*")
            # Fetch the regional ranking caches
            country_keys = await self.redis.keys("country_ranking:*")

            total_keys = ranking_keys + country_keys
            total_size = 0
            for key in total_keys[:100]:  # cap how many keys are checked, to avoid performance issues
                try:
                    size = await self.redis.memory_usage(key)
@@ -561,7 +558,7 @@ class RankingCacheService:
                        total_size += size
                except Exception:
                    continue

            return {
                "cached_user_rankings": len(ranking_keys),
                "cached_country_rankings": len(country_keys),

View File

@@ -2,24 +2,23 @@
User cache service
Caches user information, providing hot caching and real-time refresh
"""

from __future__ import annotations

-import asyncio
import json
-from datetime import UTC, datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal
+from datetime import datetime
+from typing import TYPE_CHECKING, Any

from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED
-from app.database.pp_best_score import PPBestScore
-from app.database.score import Score, ScoreResp
+from app.database.score import ScoreResp
from app.log import logger
from app.models.score import GameMode
from redis.asyncio import Redis
-from sqlmodel import col, exists, select
+from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession

if TYPE_CHECKING:
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
class DateTimeEncoder(json.JSONEncoder):
    """Custom JSON encoder that supports datetime serialization."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
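For reference, an encoder like the one above plugs into `json.dumps` via its `cls` argument. A self-contained usage example (the class body mirrors the lines shown; the payload is illustrative):

import json
from datetime import UTC, datetime


class DateTimeEncoder(json.JSONEncoder):
    """Datetimes become ISO-8601 strings; everything else falls through."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


payload = {"last_visit": datetime(2025, 8, 22, tzinfo=UTC)}
print(json.dumps(payload, cls=DateTimeEncoder))
# {"last_visit": "2025-08-22T00:00:00+00:00"}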
@@ -48,16 +47,16 @@ class UserCacheService:
        self._refreshing = False
        self._background_tasks: set = set()

-    def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
+    def _get_v1_user_cache_key(
+        self, user_id: int, ruleset: GameMode | None = None
+    ) -> str:
        """Build the V1 user cache key."""
        if ruleset:
            return f"v1_user:{user_id}:ruleset:{ruleset}"
        return f"v1_user:{user_id}"

    async def get_v1_user_from_cache(
-        self,
-        user_id: int,
-        ruleset: GameMode | None = None
+        self, user_id: int, ruleset: GameMode | None = None
    ) -> dict | None:
        """Fetch V1 user info from the cache."""
        try:
@@ -72,11 +71,11 @@ class UserCacheService:
            return None

    async def cache_v1_user(
        self,
        user_data: dict,
        user_id: int,
        ruleset: GameMode | None = None,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
    ):
        """Cache V1 user info."""
        try:
@@ -97,7 +96,9 @@ class UserCacheService:
            keys = await self.redis.keys(pattern)
            if keys:
                await self.redis.delete(*keys)
-                logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
+                logger.info(
+                    f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
+                )
        except Exception as e:
            logger.error(f"Error invalidating V1 user cache: {e}")
@@ -108,31 +109,25 @@ class UserCacheService:
return f"user:{user_id}" return f"user:{user_id}"
def _get_user_scores_cache_key( def _get_user_scores_cache_key(
self, self,
user_id: int, user_id: int,
score_type: str, score_type: str,
mode: GameMode | None = None, mode: GameMode | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0 offset: int = 0,
) -> str: ) -> str:
"""生成用户成绩缓存键""" """生成用户成绩缓存键"""
mode_part = f":{mode}" if mode else "" mode_part = f":{mode}" if mode else ""
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}" return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
def _get_user_beatmapsets_cache_key( def _get_user_beatmapsets_cache_key(
self, self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
user_id: int,
beatmapset_type: str,
limit: int = 100,
offset: int = 0
) -> str: ) -> str:
"""生成用户谱面集缓存键""" """生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}" return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache( async def get_user_from_cache(
self, self, user_id: int, ruleset: GameMode | None = None
user_id: int,
ruleset: GameMode | None = None
) -> UserResp | None: ) -> UserResp | None:
"""从缓存获取用户信息""" """从缓存获取用户信息"""
try: try:
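The key helpers above embed every query parameter in the key, so each distinct pagination is cached independently and invalidation can match on prefixes. Sample keys built from the f-strings shown (the `best` score type is illustrative, not taken from this diff):

# Keys produced by the helpers above (values illustrative):
#   user:123                          <- _get_user_cache_key(123)
#   v1_user:123:ruleset:osu           <- _get_v1_user_cache_key(123, "osu")
#   user:123:scores:best:osu:limit:100:offset:0
#   user:123:beatmapsets:favourite:limit:100:offset:0
user_id, mode = 123, "osu"
mode_part = f":{mode}" if mode else ""
key = f"user:{user_id}:scores:best{mode_part}:limit:100:offset:0"
assert key == "user:123:scores:best:osu:limit:100:offset:0"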
@@ -148,10 +143,10 @@ class UserCacheService:
            return None

    async def cache_user(
        self,
        user_resp: UserResp,
        ruleset: GameMode | None = None,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
    ):
        """Cache user info."""
        try:
@@ -173,14 +168,18 @@ class UserCacheService:
        score_type: str,
        mode: GameMode | None = None,
        limit: int = 100,
-        offset: int = 0
+        offset: int = 0,
    ) -> list[ScoreResp] | None:
        """Fetch user scores from the cache."""
        try:
-            cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+            cache_key = self._get_user_scores_cache_key(
+                user_id, score_type, mode, limit, offset
+            )
            cached_data = await self.redis.get(cache_key)
            if cached_data:
-                logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
+                logger.debug(
+                    f"User scores cache hit for user {user_id}, type {score_type}"
+                )
                data = json.loads(cached_data)
                return [ScoreResp(**score_data) for score_data in data]
            return None
@@ -196,34 +195,38 @@ class UserCacheService:
        mode: GameMode | None = None,
        limit: int = 100,
        offset: int = 0,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
    ):
        """Cache user scores."""
        try:
            if expire_seconds is None:
                expire_seconds = settings.user_scores_cache_expire_seconds

-            cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+            cache_key = self._get_user_scores_cache_key(
+                user_id, score_type, mode, limit, offset
+            )
            # Use model_dump_json() rather than model_dump() + json.dumps()
            scores_json_list = [score.model_dump_json() for score in scores]
            cached_data = f"[{','.join(scores_json_list)}]"

            await self.redis.setex(cache_key, expire_seconds, cached_data)
-            logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
+            logger.debug(
+                f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
+            )
        except Exception as e:
            logger.error(f"Error caching user scores: {e}")

    async def get_user_beatmapsets_from_cache(
-        self,
-        user_id: int,
-        beatmapset_type: str,
-        limit: int = 100,
-        offset: int = 0
+        self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
    ) -> list[Any] | None:
        """Fetch user beatmapsets from the cache."""
        try:
-            cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+            cache_key = self._get_user_beatmapsets_cache_key(
+                user_id, beatmapset_type, limit, offset
+            )
            cached_data = await self.redis.get(cache_key)
            if cached_data:
-                logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
+                logger.debug(
+                    f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
+                )
                return json.loads(cached_data)
            return None
        except Exception as e:
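The `f"[{','.join(scores_json_list)}]"` trick above builds a JSON array by concatenating fragments that each model has already serialized, skipping a model_dump() + json.dumps() double pass over the data. A minimal demonstration with a toy Pydantic model (the `Item` class is illustrative):

import json

from pydantic import BaseModel


class Item(BaseModel):
    id: int
    name: str


items = [Item(id=1, name="a"), Item(id=2, name="b")]

# Each model serializes itself; joining the fragments with commas and
# wrapping them in brackets yields a valid JSON array.
cached = f"[{','.join(item.model_dump_json() for item in items)}]"
assert json.loads(cached) == [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]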
@@ -237,23 +240,27 @@ class UserCacheService:
        beatmapsets: list[Any],
        limit: int = 100,
        offset: int = 0,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
    ):
        """Cache user beatmapsets."""
        try:
            if expire_seconds is None:
                expire_seconds = settings.user_beatmapsets_cache_expire_seconds

-            cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+            cache_key = self._get_user_beatmapsets_cache_key(
+                user_id, beatmapset_type, limit, offset
+            )
            # Use model_dump_json() for objects that have it; fall back to safe_json_dumps
            serialized_beatmapsets = []
            for bms in beatmapsets:
-                if hasattr(bms, 'model_dump_json'):
+                if hasattr(bms, "model_dump_json"):
                    serialized_beatmapsets.append(bms.model_dump_json())
                else:
                    serialized_beatmapsets.append(safe_json_dumps(bms))
            cached_data = f"[{','.join(serialized_beatmapsets)}]"

            await self.redis.setex(cache_key, expire_seconds, cached_data)
-            logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
+            logger.debug(
+                f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
+            )
        except Exception as e:
            logger.error(f"Error caching user beatmapsets: {e}")
@@ -269,7 +276,9 @@ class UserCacheService:
        except Exception as e:
            logger.error(f"Error invalidating user cache: {e}")

-    async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
+    async def invalidate_user_scores_cache(
+        self, user_id: int, mode: GameMode | None = None
+    ):
        """Invalidate the user scores cache."""
        try:
            # Delete score-related caches for the user
@@ -278,7 +287,9 @@ class UserCacheService:
            keys = await self.redis.keys(pattern)
            if keys:
                await self.redis.delete(*keys)
-                logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
+                logger.info(
+                    f"Invalidated {len(keys)} score cache entries for user {user_id}"
+                )
        except Exception as e:
            logger.error(f"Error invalidating user scores cache: {e}")
@@ -290,12 +301,10 @@ class UserCacheService:
        self._refreshing = True
        try:
            logger.info(f"Preloading cache for {len(user_ids)} users")

            # Batch-fetch the users
            users = (
-                await session.exec(
-                    select(User).where(col(User.id).in_(user_ids))
-                )
+                await session.exec(select(User).where(col(User.id).in_(user_ids)))
            ).all()

            # Cache user info serially to avoid concurrent database access issues
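The preload above fetches all requested users with a single `IN (...)` query rather than one `session.get` per user. Isolated as a helper it would look roughly like this (sketch; `load_users_batch` is a hypothetical name, and `User` is the app's model as imported in this diff):

from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession


async def load_users_batch(session: AsyncSession, user_ids: list[int]):
    from app.database import User  # assumption: import path as used in this diff

    # One round trip: WHERE user.id IN (:id_1, :id_2, ...)
    statement = select(User).where(col(User.id).in_(user_ids))
    return (await session.exec(statement)).all()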
@@ -324,10 +333,7 @@ class UserCacheService:
logger.error(f"Error caching single user {user.id}: {e}") logger.error(f"Error caching single user {user.id}: {e}")
async def refresh_user_cache_on_score_submit( async def refresh_user_cache_on_score_submit(
self, self, session: AsyncSession, user_id: int, mode: GameMode
session: AsyncSession,
user_id: int,
mode: GameMode
): ):
"""成绩提交后刷新用户缓存""" """成绩提交后刷新用户缓存"""
try: try:
@@ -335,7 +341,7 @@ class UserCacheService:
            await self.invalidate_user_cache(user_id)
            await self.invalidate_v1_user_cache(user_id)
            await self.invalidate_user_scores_cache(user_id, mode)

            # Reload the user info immediately
            user = await session.get(User, user_id)
            if user and user.id != BANCHOBOT_ID:
@@ -351,7 +357,7 @@ class UserCacheService:
            v1_user_keys = await self.redis.keys("v1_user:*")
            all_keys = user_keys + v1_user_keys

            total_size = 0
            for key in all_keys[:100]:  # cap how many keys are checked
                try:
                    size = await self.redis.memory_usage(key)
@@ -359,12 +365,22 @@ class UserCacheService:
                        total_size += size
                except Exception:
                    continue

            return {
-                "cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
-                "cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
+                "cached_users": len(
+                    [
+                        k
+                        for k in user_keys
+                        if ":scores:" not in k and ":beatmapsets:" not in k
+                    ]
+                ),
+                "cached_v1_users": len(
+                    [k for k in v1_user_keys if ":scores:" not in k]
+                ),
                "cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
-                "cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
+                "cached_user_beatmapsets": len(
+                    [k for k in user_keys if ":beatmapsets:" in k]
+                ),
                "total_cached_entries": len(all_keys),
                "estimated_total_size_mb": (
                    round(total_size / 1024 / 1024, 2) if total_size > 0 else 0

View File

@@ -1,92 +1,98 @@
#!/usr/bin/env python3
"""Test the ranking-cache serialization fix"""

-import asyncio
-from datetime import UTC, datetime
+from __future__ import annotations
+
import warnings
+from datetime import datetime, UTC

-from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps
+from app.service.ranking_cache_service import safe_json_dumps


def test_datetime_serialization():
    """Test datetime serialization."""
    print("🧪 Testing datetime serialization...")

    test_data = {
        "id": 1,
        "username": "test_user",
        "last_updated": datetime.now(UTC),
        "join_date": datetime(2020, 1, 1, tzinfo=UTC),
-        "stats": {
-            "pp": 1000.0,
-            "accuracy": 95.5,
-            "last_played": datetime.now(UTC)
-        }
+        "stats": {"pp": 1000.0, "accuracy": 95.5, "last_played": datetime.now(UTC)},
    }

    try:
        # Exercise the custom encoder
        json_result = safe_json_dumps(test_data)
        print("✅ datetime serialization succeeded")
        print(f"   Serialized result length: {len(json_result)}")

        # Verify it can be parsed back
        import json

        parsed = json.loads(json_result)
        assert "last_updated" in parsed
        assert isinstance(parsed["last_updated"], str)
        print("✅ The serialized JSON parses back correctly")
    except Exception as e:
        print(f"❌ datetime serialization test failed: {e}")
        import traceback

        traceback.print_exc()


def test_boolean_serialization():
    """Test boolean serialization."""
    print("\n🧪 Testing boolean serialization...")

    test_data = {
        "user": {
            "is_active": 1,  # integer boolean from the database
            "is_supporter": 0,  # integer boolean from the database
            "has_profile": True,  # normal boolean
        },
        "stats": {
            "is_ranked": 1,  # integer boolean from the database
            "verified": False,  # normal boolean
-        }
+        },
    }

    try:
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            json_result = safe_json_dumps(test_data)

            # Check for Pydantic serialization warnings
-            pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
+            pydantic_warnings = [
+                warning
+                for warning in w
+                if "PydanticSerializationUnexpectedValue" in str(warning.message)
+            ]

            if pydantic_warnings:
                print(f"⚠️ Still {len(pydantic_warnings)} boolean serialization warnings")
                for warning in pydantic_warnings:
                    print(f"   {warning.message}")
            else:
                print("✅ Boolean serialization produced no warnings")

        # Verify the serialized result
        import json

        parsed = json.loads(json_result)
        print(f"✅ Boolean serialization succeeded, result: {parsed}")
    except Exception as e:
        print(f"❌ Boolean serialization test failed: {e}")
        import traceback

        traceback.print_exc()


def test_complex_ranking_data():
    """Test serialization of complex ranking data."""
    print("\n🧪 Testing complex ranking data serialization...")

    # Simulated ranking data structure
    ranking_data = [
        {
@@ -95,8 +101,8 @@ def test_complex_ranking_data():
"id": 1, "id": 1,
"username": "player1", "username": "player1",
"country_code": "US", "country_code": "US",
"is_active": 1, # 整数布尔值 "is_active": 1, # 整数布尔值
"is_supporter": 0, # 整数布尔值 "is_supporter": 0, # 整数布尔值
"join_date": datetime(2020, 1, 1, tzinfo=UTC), "join_date": datetime(2020, 1, 1, tzinfo=UTC),
"last_visit": datetime.now(UTC), "last_visit": datetime.now(UTC),
}, },
@@ -104,9 +110,9 @@ def test_complex_ranking_data():
"pp": 8000.0, "pp": 8000.0,
"accuracy": 98.5, "accuracy": 98.5,
"play_count": 5000, "play_count": 5000,
"is_ranked": 1, # 整数布尔值 "is_ranked": 1, # 整数布尔值
"last_updated": datetime.now(UTC), "last_updated": datetime.now(UTC),
} },
}, },
{ {
"id": 2, "id": 2,
@@ -125,41 +131,47 @@ def test_complex_ranking_data():
"play_count": 4500, "play_count": 4500,
"is_ranked": 1, "is_ranked": 1,
"last_updated": datetime.now(UTC), "last_updated": datetime.now(UTC),
} },
} },
] ]
try: try:
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
json_result = safe_json_dumps(ranking_data) json_result = safe_json_dumps(ranking_data)
pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)] pydantic_warnings = [
warning
for warning in w
if "PydanticSerializationUnexpectedValue" in str(warning.message)
]
if pydantic_warnings: if pydantic_warnings:
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告") print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
for warning in pydantic_warnings: for warning in pydantic_warnings:
print(f" {warning.message}") print(f" {warning.message}")
else: else:
print("✅ 复杂排行榜数据序列化无警告") print("✅ 复杂排行榜数据序列化无警告")
# 验证序列化结果 # 验证序列化结果
import json import json
parsed = json.loads(json_result) parsed = json.loads(json_result)
assert len(parsed) == 2 assert len(parsed) == 2
assert parsed[0]["user"]["username"] == "player1" assert parsed[0]["user"]["username"] == "player1"
print(f"✅ 复杂排行榜数据序列化成功,包含 {len(parsed)} 个条目") print(f"✅ 复杂排行榜数据序列化成功,包含 {len(parsed)} 个条目")
except Exception as e: except Exception as e:
print(f"❌ 复杂排行榜数据序列化测试失败: {e}") print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
if __name__ == "__main__": if __name__ == "__main__":
print("🚀 开始排行榜缓存序列化测试\n") print("🚀 开始排行榜缓存序列化测试\n")
test_datetime_serialization() test_datetime_serialization()
test_boolean_serialization() test_boolean_serialization()
test_complex_ranking_data() test_complex_ranking_data()
print("\n🎉 排行榜缓存序列化测试完成!") print("\n🎉 排行榜缓存序列化测试完成!")