From 80d4237c5d81dde69dfa64ac45ce6f1c3592da02 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com>
Date: Fri, 22 Aug 2025 00:07:19 +0800
Subject: [PATCH] ruff fix

---
 app/calculator.py                     |  16 +--
 app/config.py                         |   4 +-
 app/database/beatmapset.py            |  15 ++-
 app/database/field_utils.py           |  20 ++--
 app/database/lazer_user.py            |   2 +-
 app/database/score.py                 |  17 ++-
 app/database/statistics.py            |   2 +-
 app/fetcher/_base.py                  |  28 +++--
 app/fetcher/beatmapset.py             |  24 ++--
 app/helpers/rate_limiter.py           |   1 +
 app/router/v1/user.py                 |  27 +++--
 app/router/v2/cache.py                |  39 +++---
 app/router/v2/ranking.py              |  61 +++++-----
 app/router/v2/score.py                |   4 +-
 app/router/v2/user.py                 |  60 +++++-----
 app/scheduler/__init__.py             |   1 +
 app/scheduler/cache_scheduler.py      |  26 ++--
 app/scheduler/user_cache_scheduler.py |  32 ++---
 app/service/beatmap_cache_service.py  |   5 +-
 app/service/ranking_cache_service.py  | 163 +++++++++++++------------
 app/service/user_cache_service.py     | 142 ++++++++++++----------
 test_ranking_serialization.py         |  90 ++++++++------
 22 files changed, 423 insertions(+), 356 deletions(-)

diff --git a/app/calculator.py b/app/calculator.py
index cc340cf..344bafe 100644
--- a/app/calculator.py
+++ b/app/calculator.py
@@ -88,6 +88,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
 
     # 使用线程池执行计算密集型操作以避免阻塞事件循环
     import asyncio
+
     loop = asyncio.get_event_loop()
 
     def _calculate_pp_sync():
@@ -131,11 +132,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
 
 
 async def pre_fetch_and_calculate_pp(
-    score: "Score",
-    beatmap_id: int,
-    session: AsyncSession,
-    redis,
-    fetcher
+    score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
 ) -> float:
     """
     优化版PP计算:预先获取beatmap文件并使用缓存
@@ -148,9 +145,7 @@ async def pre_fetch_and_calculate_pp(
     if settings.suspicious_score_check:
         beatmap_banned = (
             await session.exec(
-                select(exists()).where(
-                    col(BannedBeatmaps.beatmap_id) == beatmap_id
-                )
+                select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
             )
         ).first()
         if beatmap_banned:
@@ -184,10 +179,7 @@ async def pre_fetch_and_calculate_pp(
 
 
 async def batch_calculate_pp(
-    scores_data: list[tuple["Score", int]],
-    session: AsyncSession,
-    redis,
-    fetcher
+    scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
 ) -> list[float]:
     """
     批量计算PP:适用于重新计算或批量处理场景
diff --git a/app/config.py b/app/config.py
index 96ee2df..26d2bec 100644
--- a/app/config.py
+++ b/app/config.py
@@ -142,14 +142,14 @@ class Settings(BaseSettings):
     beatmap_cache_expire_hours: int = 24
     max_concurrent_pp_calculations: int = 10
     enable_pp_calculation_threading: bool = True
-    
+
     # 排行榜缓存设置
     enable_ranking_cache: bool = True
     ranking_cache_expire_minutes: int = 10  # 排行榜缓存过期时间(分钟)
     ranking_cache_refresh_interval_minutes: int = 10  # 排行榜缓存刷新间隔(分钟)
     ranking_cache_max_pages: int = 20  # 最多缓存的页数
     ranking_cache_top_countries: int = 20  # 缓存前N个国家的排行榜
-    
+
     # 用户缓存设置
     enable_user_cache_preload: bool = True  # 启用用户缓存预加载
     user_cache_expire_seconds: int = 300  # 用户信息缓存过期时间(秒)
diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index 447da5b..48b0408 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -8,7 +8,7 @@ from app.models.score import GameMode
 from .lazer_user import BASE_INCLUDES, User, UserResp
 
 from pydantic import BaseModel, field_validator, model_validator
-from sqlalchemy import Boolean, JSON, Column, DateTime, Text
+from sqlalchemy import JSON, Boolean, Column, DateTime, Text
 from sqlalchemy.ext.asyncio import AsyncAttrs
 from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
 from sqlmodel.ext.asyncio.session import AsyncSession
@@ -205,7 +205,18 @@ class BeatmapsetResp(BeatmapsetBase):
     favourite_count: int = 0
     recent_favourites: list[UserResp] = Field(default_factory=list)
 
-    @field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before')
+    @field_validator(
+        "nsfw",
+        "spotlight",
+        "video",
+        "can_be_hyped",
+        "discussion_locked",
+        "storyboard",
+        "discussion_enabled",
+        "is_scoreable",
+        "has_favourited",
+        mode="before",
+    )
     @classmethod
     def validate_bool_fields(cls, v):
         """将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
diff --git a/app/database/field_utils.py b/app/database/field_utils.py
index 53c1011..5f18134 100644
--- a/app/database/field_utils.py
+++ b/app/database/field_utils.py
@@ -2,13 +2,16 @@
 数据库字段类型工具
 提供处理数据库和 Pydantic 之间类型转换的工具
 """
-from typing import Any, Union
+
+from typing import Any
+
 from pydantic import field_validator
 from sqlalchemy import Boolean
 
 
 def bool_field_validator(field_name: str):
     """为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
+
     @field_validator(field_name, mode="before")
     @classmethod
     def validate_bool_field(cls, v: Any) -> bool:
@@ -16,20 +19,21 @@ def bool_field_validator(field_name: str):
         if isinstance(v, int):
             return bool(v)
         return v
+
     return validate_bool_field
 
 
 def create_bool_field(**kwargs):
     """创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
-    from sqlmodel import Field, Column
-    
+    from sqlmodel import Column, Field
+
     # 如果没有指定 sa_column,则使用 Boolean 类型
-    if 'sa_column' not in kwargs:
+    if "sa_column" not in kwargs:
         # 处理 index 参数
-        index = kwargs.pop('index', False)
+        index = kwargs.pop("index", False)
         if index:
-            kwargs['sa_column'] = Column(Boolean, index=True)
+            kwargs["sa_column"] = Column(Boolean, index=True)
         else:
-            kwargs['sa_column'] = Column(Boolean)
-    
+            kwargs["sa_column"] = Column(Boolean)
+
     return Field(**kwargs)
diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py
index 515f769..232b00d 100644
--- a/app/database/lazer_user.py
+++ b/app/database/lazer_user.py
@@ -136,7 +136,7 @@ class UserBase(UTCBaseModel, SQLModel):
     is_qat: bool = False
     is_bng: bool = False
 
-    @field_validator('playmode', mode='before')
+    @field_validator("playmode", mode="before")
     @classmethod
     def validate_playmode(cls, v):
         """将字符串转换为 GameMode 枚举"""
diff --git a/app/database/score.py b/app/database/score.py
index ddfef6d..88e2cfa 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -100,7 +100,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
         sa_column=Column(JSON), default_factory=dict
     )
 
-    @field_validator('maximum_statistics', mode='before')
+    @field_validator("maximum_statistics", mode="before")
     @classmethod
     def validate_maximum_statistics(cls, v):
         """处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
@@ -151,7 +151,7 @@ class Score(ScoreBase, table=True):
     gamemode: GameMode = Field(index=True)
     pinned_order: int = Field(default=0, exclude=True)
 
-    @field_validator('gamemode', mode='before')
+    @field_validator("gamemode", mode="before")
     @classmethod
     def validate_gamemode(cls, v):
         """将字符串转换为 GameMode 枚举"""
@@ -209,7 +209,16 @@ class ScoreResp(ScoreBase):
     ranked: bool = False
     current_user_attributes: CurrentUserAttributes | None = None
 
-    @field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before')
+    @field_validator(
+        "has_replay",
+        "passed",
+        "preserve",
"is_perfect_combo", + "legacy_perfect", + "processed", + "ranked", + mode="before", + ) @classmethod def validate_bool_fields(cls, v): """将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" @@ -217,7 +226,7 @@ class ScoreResp(ScoreBase): return bool(v) return v - @field_validator('statistics', 'maximum_statistics', mode='before') + @field_validator("statistics", "maximum_statistics", mode="before") @classmethod def validate_statistics_fields(cls, v): """处理统计字段中的字符串键,转换为 HitResult 枚举""" diff --git a/app/database/statistics.py b/app/database/statistics.py index 04b037d..01226cb 100644 --- a/app/database/statistics.py +++ b/app/database/statistics.py @@ -44,7 +44,7 @@ class UserStatisticsBase(SQLModel): replays_watched_by_others: int = Field(default=0) is_ranked: bool = Field(default=True) - @field_validator('mode', mode='before') + @field_validator("mode", mode="before") @classmethod def validate_mode(cls, v): """将字符串转换为 GameMode 枚举""" diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index ec0bb4a..97d16db 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from typing import Optional from app.dependencies.database import get_redis from app.log import logger @@ -11,6 +10,7 @@ from httpx import AsyncClient class TokenAuthError(Exception): """Token 授权失败异常""" + pass @@ -55,7 +55,7 @@ class BaseFetcher: return await self._request_with_retry(url, method, **kwargs) async def _request_with_retry( - self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs + self, url: str, method: str = "GET", max_retries: int | None = None, **kwargs ) -> dict: """ 带重试机制的请求方法 @@ -64,7 +64,7 @@ class BaseFetcher: max_retries = self.max_retries last_error = None - + for attempt in range(max_retries + 1): try: # 检查 token 是否过期 @@ -126,7 +126,9 @@ class BaseFetcher: ) continue else: - logger.error(f"Request failed after {max_retries + 1} attempts: {e}") + logger.error( + f"Request failed after {max_retries + 1} attempts: {e}" + ) break # 如果所有重试都失败了 @@ -194,9 +196,13 @@ class BaseFetcher: f"fetcher:refresh_token:{self.client_id}", self.refresh_token, ) - logger.info(f"Successfully refreshed access token for client {self.client_id}") + logger.info( + f"Successfully refreshed access token for client {self.client_id}" + ) except Exception as e: - logger.error(f"Failed to refresh access token for client {self.client_id}: {e}") + logger.error( + f"Failed to refresh access token for client {self.client_id}: {e}" + ) # 清除无效的 token,要求重新授权 self.access_token = "" self.refresh_token = "" @@ -204,7 +210,9 @@ class BaseFetcher: redis = get_redis() await redis.delete(f"fetcher:access_token:{self.client_id}") await redis.delete(f"fetcher:refresh_token:{self.client_id}") - logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}") + logger.warning( + f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}" + ) raise async def _trigger_reauthorization(self) -> None: @@ -216,18 +224,18 @@ class BaseFetcher: f"Authentication failed after {self._auth_retry_count} attempts. " f"Triggering reauthorization for client {self.client_id}" ) - + # 清除内存中的 token self.access_token = "" self.refresh_token = "" self.token_expiry = 0 self._auth_retry_count = 0 # 重置重试计数器 - + # 清除 Redis 中的 token redis = get_redis() await redis.delete(f"fetcher:access_token:{self.client_id}") await redis.delete(f"fetcher:refresh_token:{self.client_id}") - + logger.warning( f"All tokens cleared for client {self.client_id}. 
" f"Please re-authorize using: {self.authorize_url}" diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index f87f2ed..9f3c025 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -101,6 +101,7 @@ class BeatmapsetFetcher(BaseFetcher): return json.loads(cursor_json) except Exception: return {} + async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp: logger.opt(colors=True).debug( f"[BeatmapsetFetcher] get_beatmapset: {beatmap_set_id}" @@ -164,9 +165,7 @@ class BeatmapsetFetcher(BaseFetcher): # 将结果缓存 15 分钟 cache_ttl = 15 * 60 # 15 分钟 await redis_client.set( - cache_key, - json.dumps(api_response, separators=(",", ":")), - ex=cache_ttl + cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl ) logger.opt(colors=True).debug( @@ -178,10 +177,12 @@ class BeatmapsetFetcher(BaseFetcher): # 智能预取:只在用户明确搜索时才预取,避免过多API请求 # 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取 - if (api_response.get("cursor") and - (query.q or query.s != "leaderboard" or cursor)): + if api_response.get("cursor") and ( + query.q or query.s != "leaderboard" or cursor + ): # 在后台预取下1页(减少预取量) import asyncio + # 不立即创建任务,而是延迟一段时间再预取 async def delayed_prefetch(): await asyncio.sleep(3.0) # 延迟3秒 @@ -200,8 +201,11 @@ class BeatmapsetFetcher(BaseFetcher): return resp async def prefetch_next_pages( - self, query: SearchQueryModel, current_cursor: Cursor, - redis_client: redis.Redis, pages: int = 3 + self, + query: SearchQueryModel, + current_cursor: Cursor, + redis_client: redis.Redis, + pages: int = 3, ) -> None: """预取下几页内容""" if not current_cursor: @@ -269,7 +273,7 @@ class BeatmapsetFetcher(BaseFetcher): await redis_client.set( next_cache_key, json.dumps(api_response, separators=(",", ":")), - ex=prefetch_ttl + ex=prefetch_ttl, ) logger.opt(colors=True).debug( @@ -317,7 +321,6 @@ class BeatmapsetFetcher(BaseFetcher): params=params, ) - if api_response.get("cursor"): cursor_dict = api_response["cursor"] api_response["cursor_string"] = self._encode_cursor(cursor_dict) @@ -327,7 +330,7 @@ class BeatmapsetFetcher(BaseFetcher): await redis_client.set( cache_key, json.dumps(api_response, separators=(",", ":")), - ex=cache_ttl + ex=cache_ttl, ) logger.opt(colors=True).info( @@ -335,7 +338,6 @@ class BeatmapsetFetcher(BaseFetcher): f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)" ) - if api_response.get("cursor"): await self.prefetch_next_pages( query, api_response["cursor"], redis_client, pages=2 diff --git a/app/helpers/rate_limiter.py b/app/helpers/rate_limiter.py index ad1f27b..40efc9b 100644 --- a/app/helpers/rate_limiter.py +++ b/app/helpers/rate_limiter.py @@ -5,6 +5,7 @@ Rate limiter for osu! API requests to avoid abuse detection. 
- 突发:短时间内最多 200 次额外请求 - 建议:每分钟不超过 60 次请求以避免滥用检测 """ + from __future__ import annotations import asyncio diff --git a/app/router/v1/user.py b/app/router/v1/user.py index d8938ed..bb57c3a 100644 --- a/app/router/v1/user.py +++ b/app/router/v1/user.py @@ -4,7 +4,6 @@ import asyncio from datetime import datetime from typing import Literal -from app.config import settings from app.database.lazer_user import User from app.database.statistics import UserStatistics, UserStatisticsResp from app.dependencies.database import Database, get_redis @@ -56,7 +55,7 @@ class V1User(AllStrModel): # 确保 user_id 不为 None if db_user.id is None: raise ValueError("User ID cannot be None") - + ruleset = ruleset or db_user.playmode current_statistics: UserStatistics | None = None for i in await db_user.awaitable_attrs.statistics: @@ -118,26 +117,28 @@ async def get_user( ): redis = get_redis() cache_service = get_user_cache_service(redis) - + # 确定查询方式和用户ID is_id_query = type == "id" or user.isdigit() - + # 解析 ruleset ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None - + # 如果是 ID 查询,先尝试从缓存获取 cached_v1_user = None user_id_for_cache = None - + if is_id_query: try: user_id_for_cache = int(user) - cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset) + cached_v1_user = await cache_service.get_v1_user_from_cache( + user_id_for_cache, ruleset + ) if cached_v1_user: return [V1User(**cached_v1_user)] except (ValueError, TypeError): pass # 不是有效的用户ID,继续数据库查询 - + # 从数据库查询用户 db_user = ( await session.exec( @@ -146,23 +147,23 @@ async def get_user( ) ) ).first() - + if not db_user: return [] - + try: # 生成用户数据 v1_user = await V1User.from_db(session, db_user, ruleset) - + # 异步缓存结果(如果有用户ID) if db_user.id is not None: user_data = v1_user.model_dump() asyncio.create_task( cache_service.cache_v1_user(user_data, db_user.id, ruleset) ) - + return [v1_user] - + except KeyError: raise HTTPException(400, "Invalid request") except ValueError as e: diff --git a/app/router/v2/cache.py b/app/router/v2/cache.py index 7dd970c..08a0b27 100644 --- a/app/router/v2/cache.py +++ b/app/router/v2/cache.py @@ -2,15 +2,15 @@ 缓存管理和监控接口 提供缓存统计、清理和预热功能 """ + from __future__ import annotations from app.dependencies.database import get_redis -from app.dependencies.user import get_current_user from app.service.user_cache_service import get_user_cache_service from .router import router -from fastapi import Depends, HTTPException, Security +from fastapi import Depends, HTTPException from pydantic import BaseModel from redis.asyncio import Redis @@ -34,7 +34,7 @@ async def get_cache_stats( try: cache_service = get_user_cache_service(redis) user_cache_stats = await cache_service.get_cache_stats() - + # 获取 Redis 基本信息 redis_info = await redis.info() redis_stats = { @@ -47,20 +47,17 @@ async def get_cache_stats( "evicted_keys": redis_info.get("evicted_keys", 0), "expired_keys": redis_info.get("expired_keys", 0), } - + # 计算缓存命中率 hits = redis_stats["keyspace_hits"] misses = redis_stats["keyspace_misses"] hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0 redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2) - - return CacheStatsResponse( - user_cache=user_cache_stats, - redis_info=redis_stats - ) - + + return CacheStatsResponse(user_cache=user_cache_stats, redis_info=redis_stats) + except Exception as e: - raise HTTPException(500, f"Failed to get cache stats: {str(e)}") + raise HTTPException(500, f"Failed to get cache stats: {e!s}") @router.post( @@ -80,7 +77,7 @@ async def 
@@ -80,7 +77,7 @@ async def invalidate_user_cache(
         await cache_service.invalidate_v1_user_cache(user_id)
         return {"message": f"Cache invalidated for user {user_id}"}
     except Exception as e:
-        raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
+        raise HTTPException(500, f"Failed to invalidate cache: {e!s}")
 
 
 @router.post(
@@ -98,15 +95,15 @@ async def clear_all_user_cache(
         user_keys = await redis.keys("user:*")
         v1_user_keys = await redis.keys("v1_user:*")
         all_keys = user_keys + v1_user_keys
-        
+
         if all_keys:
             await redis.delete(*all_keys)
             return {"message": f"Cleared {len(all_keys)} cache entries"}
         else:
             return {"message": "No cache entries found"}
-        
+
     except Exception as e:
-        raise HTTPException(500, f"Failed to clear cache: {str(e)}")
+        raise HTTPException(500, f"Failed to clear cache: {e!s}")
 
 
 class CacheWarmupRequest(BaseModel):
@@ -127,18 +124,22 @@ async def warmup_cache(
 ):
     try:
         cache_service = get_user_cache_service(redis)
-        
+
         if request.user_ids:
             # 预热指定用户
             from app.dependencies.database import with_db
+
             async with with_db() as session:
                 await cache_service.preload_user_cache(session, request.user_ids)
             return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
         else:
             # 预热活跃用户
-            from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
+            from app.scheduler.user_cache_scheduler import (
+                schedule_user_cache_preload_task,
+            )
+
             await schedule_user_cache_preload_task()
             return {"message": f"Warmed up cache for top {request.limit} active users"}
-        
+
     except Exception as e:
-        raise HTTPException(500, f"Failed to warmup cache: {str(e)}")
+        raise HTTPException(500, f"Failed to warmup cache: {e!s}")
diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py
index 3bc640f..dcb6d8f 100644
--- a/app/router/v2/ranking.py
+++ b/app/router/v2/ranking.py
@@ -45,24 +45,24 @@ async def get_country_ranking(
     # 获取 Redis 连接和缓存服务
     redis = get_redis()
     cache_service = get_ranking_cache_service(redis)
-    
+
     # 尝试从缓存获取数据
     cached_data = await cache_service.get_cached_country_ranking(ruleset, page)
-    
+
     if cached_data:
         # 从缓存返回数据
         return CountryResponse(
             ranking=[CountryStatistics.model_validate(item) for item in cached_data]
         )
-    
+
     # 缓存未命中,从数据库查询
     response = CountryResponse(ranking=[])
     countries = (await session.exec(select(User.country_code).distinct())).all()
-    
+
     for country in countries:
         if not country:  # 跳过空的国家代码
             continue
-        
+
         statistics = (
             await session.exec(
                 select(UserStatistics).where(
@@ -73,10 +73,10 @@ async def get_country_ranking(
                 )
             )
         ).all()
-        
+
         if not statistics:  # 跳过没有数据的国家
             continue
-        
+
         pp = 0
         country_stats = CountryStatistics(
             code=country,
@@ -92,27 +92,28 @@ async def get_country_ranking(
             pp += stat.pp
         country_stats.performance = round(pp)
         response.ranking.append(country_stats)
-    
+
     response.ranking.sort(key=lambda x: x.performance, reverse=True)
-    
+
     # 分页处理
     page_size = 50
     start_idx = (page - 1) * page_size
     end_idx = start_idx + page_size
-    
+
     # 获取当前页的数据
     current_page_data = response.ranking[start_idx:end_idx]
-    
+
     # 异步缓存数据(不等待完成)
     cache_data = [item.model_dump() for item in current_page_data]
     cache_task = cache_service.cache_country_ranking(
         ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
     )
-    
+
     # 创建后台任务来缓存数据
     import asyncio
+
     asyncio.create_task(cache_task)
-    
+
     # 返回当前页的结果
     response.ranking = current_page_data
     return response
@@ -142,20 +143,16 @@ async def get_user_ranking(
     # 获取 Redis 连接和缓存服务
     redis = get_redis()
     cache_service = get_ranking_cache_service(redis)
-    
+
     # 尝试从缓存获取数据
-    cached_data = await cache_service.get_cached_ranking(
-        ruleset, type, country, page
-    )
-    
+    cached_data = await cache_service.get_cached_ranking(ruleset, type, country, page)
+
     if cached_data:
         # 从缓存返回数据
         return TopUsersResponse(
-            ranking=[
-                UserStatisticsResp.model_validate(item) for item in cached_data
-            ]
+            ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
         )
-    
+
     # 缓存未命中,从数据库查询
     wheres = [
         col(UserStatistics.mode) == ruleset,
@@ -170,7 +167,7 @@ async def get_user_ranking(
         order_by = col(UserStatistics.ranked_score).desc()
     if country:
         wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
-    
+
     statistics_list = await session.exec(
         select(UserStatistics)
         .where(*wheres)
@@ -178,7 +175,7 @@ async def get_user_ranking(
         .limit(50)
         .offset(50 * (page - 1))
     )
-    
+
     # 转换为响应格式
     ranking_data = []
     for statistics in statistics_list:
@@ -186,18 +183,24 @@ async def get_user_ranking(
             statistics, session, None, include
         )
         ranking_data.append(user_stats_resp)
-    
+
     # 异步缓存数据(不等待完成)
     # 使用配置文件中的TTL设置
     cache_data = [item.model_dump() for item in ranking_data]
     cache_task = cache_service.cache_ranking(
-        ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60
+        ruleset,
+        type,
+        cache_data,
+        country,
+        page,
+        ttl=settings.ranking_cache_expire_minutes * 60,
     )
-    
+
     # 创建后台任务来缓存数据
     import asyncio
+
     asyncio.create_task(cache_task)
-    
+
     resp = TopUsersResponse(ranking=ranking_data)
     return resp
@@ -328,4 +331,4 @@ async def get_ranking_cache_stats(
 
     cache_service = get_ranking_cache_service(redis)
     stats = await cache_service.get_cache_stats()
-    return stats """
\ No newline at end of file
+    return stats """
diff --git a/app/router/v2/score.py b/app/router/v2/score.py
index 648505a..d3eb903 100644
--- a/app/router/v2/score.py
+++ b/app/router/v2/score.py
@@ -183,7 +183,7 @@ async def submit_score(
         }
         db.add(rank_event)
         await db.commit()
-    
+
     # 成绩提交后刷新用户缓存
     try:
         user_cache_service = get_user_cache_service(redis)
@@ -193,7 +193,7 @@ async def submit_score(
         )
     except Exception as e:
         logger.error(f"Failed to refresh user cache after score submit: {e}")
-    
+
     background_task.add_task(process_user_achievement, resp.id)
     return resp
 
diff --git a/app/router/v2/user.py b/app/router/v2/user.py
index 28ef086..9e67de6 100644
--- a/app/router/v2/user.py
+++ b/app/router/v2/user.py
@@ -57,25 +57,27 @@ async def get_users(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
-    
+
     if user_ids:
         # 先尝试从缓存获取
         cached_users = []
         uncached_user_ids = []
-        
+
         for user_id in user_ids[:50]:  # 限制50个
             cached_user = await cache_service.get_user_from_cache(user_id)
             if cached_user:
                 cached_users.append(cached_user)
             else:
                 uncached_user_ids.append(user_id)
-        
+
         # 查询未缓存的用户
         if uncached_user_ids:
             searched_users = (
-                await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
+                await session.exec(
+                    select(User).where(col(User.id).in_(uncached_user_ids))
+                )
             ).all()
-            
+
             # 将查询到的用户添加到缓存并返回
             for searched_user in searched_users:
                 if searched_user.id != BANCHOBOT_ID:
@@ -87,7 +89,7 @@ async def get_users(
                     cached_users.append(user_resp)
                     # 异步缓存,不阻塞响应
                     asyncio.create_task(cache_service.cache_user(user_resp))
-        
+
         return BatchUserResponse(users=cached_users)
     else:
         searched_users = (await session.exec(select(User).limit(50))).all()
@@ -102,7 +104,7 @@ async def get_users(
                 users.append(user_resp)
                 # 异步缓存
                 asyncio.create_task(cache_service.cache_user(user_resp))
-        
+
         return BatchUserResponse(users=users)
 
 
@@ -121,14 +123,14 @@ async def get_user_info_ruleset(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
-    
+
     # 如果是数字ID,先尝试从缓存获取
     if user_id.isdigit():
         user_id_int = int(user_id)
         cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset)
         if cached_user:
             return cached_user
-    
+
     searched_user = (
         await session.exec(
             select(User).where(
@@ -140,17 +142,17 @@ async def get_user_info_ruleset(
     ).first()
     if not searched_user or searched_user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
-    
+
     user_resp = await UserResp.from_db(
         searched_user,
         session,
         include=SEARCH_INCLUDED,
         ruleset=ruleset,
     )
-    
+
     # 异步缓存结果
     asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
-    
+
     return user_resp
 
 
@@ -169,14 +171,14 @@ async def get_user_info(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
-    
+
     # 如果是数字ID,先尝试从缓存获取
     if user_id.isdigit():
         user_id_int = int(user_id)
         cached_user = await cache_service.get_user_from_cache(user_id_int)
         if cached_user:
             return cached_user
-    
+
     searched_user = (
         await session.exec(
             select(User).where(
@@ -188,16 +190,16 @@ async def get_user_info(
     ).first()
     if not searched_user or searched_user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
-    
+
     user_resp = await UserResp.from_db(
         searched_user,
         session,
         include=SEARCH_INCLUDED,
     )
-    
+
     # 异步缓存结果
     asyncio.create_task(cache_service.cache_user(user_resp))
-    
+
     return user_resp
 
 
@@ -218,7 +220,7 @@ async def get_user_beatmapsets(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
-    
+
     # 先尝试从缓存获取
     cached_result = await cache_service.get_user_beatmapsets_from_cache(
         user_id, type.value, limit, offset
@@ -229,7 +231,7 @@ async def get_user_beatmapsets(
             return [BeatmapPlaycountsResp(**item) for item in cached_result]
         else:
             return [BeatmapsetResp(**item) for item in cached_result]
-    
+
     user = await session.get(User, user_id)
     if not user or user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
@@ -275,10 +277,14 @@ async def get_user_beatmapsets(
     # 异步缓存结果
     async def cache_beatmapsets():
         try:
-            await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
+            await cache_service.cache_user_beatmapsets(
+                user_id, type.value, resp, limit, offset
+            )
         except Exception as e:
-            logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
-    
+            logger.error(
+                f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
+            )
+
     asyncio.create_task(cache_beatmapsets())
     return resp
 
@@ -311,7 +317,7 @@ async def get_user_scores(
 ):
     redis = get_redis()
     cache_service = get_user_cache_service(redis)
-    
+
     # 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
     cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
     cached_scores = await cache_service.get_user_scores_from_cache(
@@ -319,7 +325,7 @@ async def get_user_scores(
     )
     if cached_scores is not None:
         return cached_scores
-    
+
     db_user = await session.get(User, user_id)
     if not db_user or db_user.id == BANCHOBOT_ID:
         raise HTTPException(404, detail="User not found")
@@ -355,7 +361,7 @@ async def get_user_scores(
     ).all()
     if not scores:
         return []
-    
+
     score_responses = [
         await ScoreResp.from_db(
             session,
@@ -363,14 +369,14 @@ async def get_user_scores(
         )
         for score in scores
     ]
-    
+
     # 异步缓存结果
     asyncio.create_task(
         cache_service.cache_user_scores(
             user_id, type, score_responses, mode, limit, offset, cache_expire
         )
    )
-    
+
     return score_responses
 
 
diff --git a/app/scheduler/__init__.py b/app/scheduler/__init__.py
index c9d77d0..d6e4f7c 100644
--- a/app/scheduler/__init__.py
+++ b/app/scheduler/__init__.py
@@ -1,4 +1,5 @@
 """缓存调度器模块"""
+
 from __future__ import annotations
 
 from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py
index 72722e3..8edecfb 100644
--- a/app/scheduler/cache_scheduler.py
+++ b/app/scheduler/cache_scheduler.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 
 from app.config import settings
-from app.dependencies.database import get_db, get_redis
+from app.dependencies.database import get_redis
 from app.dependencies.fetcher import get_fetcher
 from app.log import logger
 from app.scheduler.user_cache_scheduler import (
@@ -44,10 +44,10 @@ class CacheScheduler:
         """运行调度器主循环"""
         # 启动时立即执行一次预热
         await self._warmup_cache()
-        
+
         # 启动时执行一次排行榜缓存刷新
         await self._refresh_ranking_cache()
-        
+
         # 启动时执行一次用户缓存预热
         await self._warmup_user_cache()
 
@@ -55,14 +55,16 @@ class CacheScheduler:
         ranking_cache_counter = 0
         user_cache_counter = 0
         user_cleanup_counter = 0
-        
+
         # 从配置文件获取间隔设置
         check_interval = 5 * 60  # 5分钟检查间隔
         beatmap_cache_interval = 30 * 60  # 30分钟beatmap缓存间隔
-        ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60  # 从配置读取
+        ranking_cache_interval = (
+            settings.ranking_cache_refresh_interval_minutes * 60
+        )  # 从配置读取
         user_cache_interval = 15 * 60  # 15分钟用户缓存预加载间隔
         user_cleanup_interval = 60 * 60  # 60分钟用户缓存清理间隔
-        
+
         beatmap_cache_cycles = beatmap_cache_interval // check_interval
         ranking_cache_cycles = ranking_cache_interval // check_interval
         user_cache_cycles = user_cache_interval // check_interval
@@ -90,12 +92,12 @@ class CacheScheduler:
                 if ranking_cache_counter >= ranking_cache_cycles:
                     await self._refresh_ranking_cache()
                     ranking_cache_counter = 0
-                
+
                 # 用户缓存预加载
                 if user_cache_counter >= user_cache_cycles:
                     await self._preload_user_cache()
                     user_cache_counter = 0
-                
+
                 # 用户缓存清理
                 if user_cleanup_counter >= user_cleanup_cycles:
                     await self._cleanup_user_cache()
@@ -129,15 +131,14 @@ class CacheScheduler:
             logger.info("Starting ranking cache refresh...")
 
             redis = get_redis()
-            
+
             # 导入排行榜缓存服务
+            # 使用独立的数据库会话
+            from app.dependencies.database import with_db
             from app.service.ranking_cache_service import (
-                get_ranking_cache_service,
                 schedule_ranking_refresh_task,
             )
-            # 使用独立的数据库会话
-            from app.dependencies.database import with_db
 
             async with with_db() as session:
                 await schedule_ranking_refresh_task(session, redis)
 
@@ -171,6 +172,7 @@ class CacheScheduler:
 
 # Beatmap缓存调度器(保持向后兼容)
 class BeatmapsetCacheScheduler(CacheScheduler):
     """谱面集缓存调度器 - 为了向后兼容"""
+
     pass
 
diff --git a/app/scheduler/user_cache_scheduler.py b/app/scheduler/user_cache_scheduler.py
index cf43af2..b98df8a 100644
--- a/app/scheduler/user_cache_scheduler.py
+++ b/app/scheduler/user_cache_scheduler.py
@@ -1,15 +1,15 @@
 """
 用户缓存预热任务调度器
 """
+
 from __future__ import annotations
 
 import asyncio
 from datetime import UTC, datetime, timedelta
 
 from app.config import settings
-from app.database import User
 from app.database.score import Score
-from app.dependencies.database import get_db, get_redis
+from app.dependencies.database import get_redis
 from app.log import logger
 from app.service.user_cache_service import get_user_cache_service
 
@@ -25,16 +25,17 @@
 
     try:
         logger.info("Starting user cache preload task...")
-        
+
         redis = get_redis()
         cache_service = get_user_cache_service(redis)
-        
+
         # 使用独立的数据库会话
         from app.dependencies.database import with_db
+
         async with with_db() as session:
             # 获取最近24小时内活跃的用户(提交过成绩的用户)
             recent_time = datetime.now(UTC) - timedelta(hours=24)
-            
+
             active_user_ids = (
                 await session.exec(
                     select(Score.user_id, func.count().label("score_count"))
@@ -44,7 +45,7 @@ async def schedule_user_cache_preload_task():
                 .limit(settings.user_cache_max_preload_users)  # 使用配置中的限制
                 )
             ).all()
-            
+
             if active_user_ids:
                 user_ids = [row[0] for row in active_user_ids]
                 await cache_service.preload_user_cache(session, user_ids)
@@ -62,17 +63,18 @@ async def schedule_user_cache_warmup_task():
     """定时用户缓存预热任务 - 预加载排行榜前100用户"""
     try:
         logger.info("Starting user cache warmup task...")
-        
+
         redis = get_redis()
         cache_service = get_user_cache_service(redis)
-        
+
         # 使用独立的数据库会话
         from app.dependencies.database import with_db
+
         async with with_db() as session:
             # 获取全球排行榜前100的用户
             from app.database.statistics import UserStatistics
             from app.models.score import GameMode
-            
+
             for mode in GameMode:
                 try:
                     top_users = (
@@ -83,15 +85,15 @@ async def schedule_user_cache_warmup_task():
                             .limit(100)
                         )
                     ).all()
-                    
+
                     if top_users:
                         user_ids = list(top_users)
                         await cache_service.preload_user_cache(session, user_ids)
                         logger.info(f"Warmed cache for top 100 users in {mode}")
-                    
+
                     # 避免过载,稍微延迟
                     await asyncio.sleep(1)
-                    
+
                 except Exception as e:
                     logger.error(f"Failed to warm cache for {mode}: {e}")
                     continue
@@ -106,13 +108,13 @@ async def schedule_user_cache_cleanup_task():
     """定时用户缓存清理任务"""
     try:
         logger.info("Starting user cache cleanup task...")
-        
+
         redis = get_redis()
-        
+
         # 清理过期的用户缓存(Redis会自动处理TTL,这里主要记录统计信息)
         cache_service = get_user_cache_service(redis)
         stats = await cache_service.get_cache_stats()
-        
+
         logger.info(f"User cache stats: {stats}")
         logger.info("User cache cleanup task completed successfully")
 
diff --git a/app/service/beatmap_cache_service.py b/app/service/beatmap_cache_service.py
index a4f6c17..2f76426 100644
--- a/app/service/beatmap_cache_service.py
+++ b/app/service/beatmap_cache_service.py
@@ -2,6 +2,7 @@
 Beatmap缓存预取服务
 用于提前缓存热门beatmap,减少成绩计算时的获取延迟
 """
+
 from __future__ import annotations
 
 import asyncio
@@ -155,9 +156,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
 
 
 async def schedule_preload_task(
-    session: AsyncSession,
-    redis: Redis,
-    fetcher: "Fetcher"
+    session: AsyncSession, redis: Redis, fetcher: "Fetcher"
 ):
     """
     定时预加载任务
diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py
index 7eff17d..8006f8d 100644
--- a/app/service/ranking_cache_service.py
+++ b/app/service/ranking_cache_service.py
@@ -2,11 +2,12 @@
 用户排行榜缓存服务
 用于缓存用户排行榜数据,减轻数据库压力
 """
+
 from __future__ import annotations
 
 import asyncio
+from datetime import UTC, datetime
 import json
-from datetime import UTC, datetime, timedelta
 from typing import TYPE_CHECKING, Literal
 
 from app.config import settings
@@ -24,7 +25,7 @@ if TYPE_CHECKING:
 
 class DateTimeEncoder(json.JSONEncoder):
     """自定义 JSON 编码器,支持 datetime 序列化"""
-    
+
     def default(self, obj):
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
 
 def safe_json_dumps(data) -> str:
     """安全的 JSON 序列化,支持 datetime 对象"""
-    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
+    return json.dumps(
+        data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
+    )
 
 
 class RankingCacheService:
@@ -84,7 +87,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_cache_key(ruleset, type, country, page)
             cached_data = await self.redis.get(cache_key)
-            
+
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -107,11 +110,7 @@ class RankingCacheService:
             # 使用配置文件的TTL设置
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(ranking_data),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
             logger.debug(f"Cached ranking data for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching ranking: {e}")
@@ -126,7 +125,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_stats_cache_key(ruleset, type, country)
             cached_data = await self.redis.get(cache_key)
-            
+
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -148,11 +147,7 @@ class RankingCacheService:
             # 使用配置文件的TTL设置,统计信息缓存时间更长
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60 * 6  # 6倍时间
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(stats),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
             logger.debug(f"Cached stats for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching stats: {e}")
@@ -166,7 +161,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_country_cache_key(ruleset, page)
             cached_data = await self.redis.get(cache_key)
-            
+
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -186,11 +181,7 @@ class RankingCacheService:
             cache_key = self._get_country_cache_key(ruleset, page)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(ranking_data),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
             logger.debug(f"Cached country ranking data for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching country ranking: {e}")
@@ -200,7 +191,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_country_stats_cache_key(ruleset)
             cached_data = await self.redis.get(cache_key)
-            
+
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -219,11 +210,7 @@ class RankingCacheService:
             cache_key = self._get_country_stats_cache_key(ruleset)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60 * 6  # 6倍时间
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(stats),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
             logger.debug(f"Cached country stats for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching country stats: {e}")
@@ -238,7 +225,9 @@ class RankingCacheService:
     ) -> None:
         """刷新排行榜缓存"""
         if self._refreshing:
-            logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
+            logger.info(
+                f"Ranking cache refresh already in progress for {ruleset}:{type}"
+            )
             return
 
         # 使用配置文件的设置
@@ -248,7 +237,7 @@ class RankingCacheService:
         self._refreshing = True
         try:
             logger.info(f"Starting ranking cache refresh for {ruleset}:{type}")
-            
+
             # 构建查询条件
             wheres = [
                 col(UserStatistics.mode) == ruleset,
@@ -256,20 +245,22 @@ class RankingCacheService:
                 col(UserStatistics.is_ranked).is_(True),
             ]
             include = ["user"]
-            
+
             if type == "performance":
                 order_by = col(UserStatistics.pp).desc()
                 include.append("rank_change_since_30_days")
             else:
                 order_by = col(UserStatistics.ranked_score).desc()
-            
+
             if country:
-                wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
+                wheres.append(
+                    col(UserStatistics.user).has(country_code=country.upper())
+                )
 
             # 获取总用户数用于统计
             total_users_query = select(UserStatistics).where(*wheres)
             total_users = len((await session.exec(total_users_query)).all())
-            
+
             # 计算统计信息
             stats = {
                 "total_users": total_users,
@@ -278,7 +269,7 @@ class RankingCacheService:
                 "ruleset": ruleset,
                 "country": country,
             }
-            
+
             # 缓存统计信息
             await self.cache_stats(ruleset, type, stats, country)
 
@@ -292,11 +283,11 @@ class RankingCacheService:
                         .limit(50)
                         .offset(50 * (page - 1))
                     )
-                    
+
                     statistics_data = statistics_list.all()
                     if not statistics_data:
                         break  # 没有更多数据
-                    
+
                     # 转换为响应格式并确保正确序列化
                     ranking_data = []
                     for statistics in statistics_data:
@@ -306,21 +297,19 @@ class RankingCacheService:
                         # 将 UserStatisticsResp 转换为字典,处理所有序列化问题
                         user_dict = json.loads(user_stats_resp.model_dump_json())
                         ranking_data.append(user_dict)
-                    
+
                     # 缓存这一页的数据
-                    await self.cache_ranking(
-                        ruleset, type, ranking_data, country, page
-                    )
-                    
+                    await self.cache_ranking(ruleset, type, ranking_data, country, page)
+
                     # 添加延迟避免数据库过载
                     if page < max_pages:
                         await asyncio.sleep(0.1)
-                    
+
                 except Exception as e:
                     logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}")
-            
+
             logger.info(f"Completed ranking cache refresh for {ruleset}:{type}")
-        
+
         except Exception as e:
             logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}")
         finally:
@@ -334,7 +323,9 @@ class RankingCacheService:
     ) -> None:
         """刷新地区排行榜缓存"""
         if self._refreshing:
-            logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
+            logger.info(
+                f"Country ranking cache refresh already in progress for {ruleset}"
+            )
             return
 
         if max_pages is None:
@@ -343,17 +334,18 @@ class RankingCacheService:
         self._refreshing = True
         try:
             logger.info(f"Starting country ranking cache refresh for {ruleset}")
-            
+
             # 获取所有国家
             from app.database import User
+
             countries = (await session.exec(select(User.country_code).distinct())).all()
-            
+
             # 计算每个国家的统计数据
             country_stats_list = []
             for country in countries:
                 if not country:  # 跳过空的国家代码
                     continue
-                
+
                 statistics = (
                     await session.exec(
                         select(UserStatistics).where(
@@ -364,10 +356,10 @@ class RankingCacheService:
                         )
                     )
                 ).all()
-                
+
                 if not statistics:  # 跳过没有数据的国家
                     continue
-                
+
                 pp = 0
                 country_stats = {
                     "code": country,
@@ -376,48 +368,48 @@ class RankingCacheService:
                     "ranked_score": 0,
                     "performance": 0,
                 }
-                
+
                 for stat in statistics:
                     country_stats["active_users"] += 1
                     country_stats["play_count"] += stat.play_count
                     country_stats["ranked_score"] += stat.ranked_score
                     pp += stat.pp
-                
+
                 country_stats["performance"] = round(pp)
                 country_stats_list.append(country_stats)
-            
+
             # 按表现分排序
             country_stats_list.sort(key=lambda x: x["performance"], reverse=True)
-            
+
             # 计算统计信息
             stats = {
                 "total_countries": len(country_stats_list),
                 "last_updated": datetime.now(UTC).isoformat(),
                 "ruleset": ruleset,
             }
-            
+
             # 缓存统计信息
             await self.cache_country_stats(ruleset, stats)
-            
+
             # 分页缓存数据(每页50个国家)
             page_size = 50
             for page in range(1, max_pages + 1):
                 start_idx = (page - 1) * page_size
                 end_idx = start_idx + page_size
-                
+
                 page_data = country_stats_list[start_idx:end_idx]
                 if not page_data:
                     break  # 没有更多数据
-                
+
                 # 缓存这一页的数据
                 await self.cache_country_ranking(ruleset, page_data, page)
-                
+
                 # 添加延迟避免Redis过载
                 if page < max_pages and page_data:
                     await asyncio.sleep(0.1)
-            
+
             logger.info(f"Completed country ranking cache refresh for {ruleset}")
-        
+
         except Exception as e:
             logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}")
         finally:
@@ -427,11 +419,12 @@ class RankingCacheService:
         """刷新所有排行榜缓存"""
         game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]
         ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
-        
+
         # 获取需要缓存的国家列表(活跃用户数量前20的国家)
         from app.database import User
+
         from sqlmodel import func
-        
+
         countries_query = (
             await session.exec(
                 select(User.country_code, func.count().label("user_count"))
@@ -441,38 +434,40 @@ class RankingCacheService:
                 .limit(settings.ranking_cache_top_countries)
             )
         ).all()
-        
+
         top_countries = [country for country, _ in countries_query]
-        
+
         refresh_tasks = []
-        
+
         # 全球排行榜
         for mode in game_modes:
            for ranking_type in ranking_types:
                 task = self.refresh_ranking_cache(session, mode, ranking_type)
                 refresh_tasks.append(task)
-        
+
         # 国家排行榜(仅前20个国家)
         for country in top_countries:
             for mode in game_modes:
                 for ranking_type in ranking_types:
-                    task = self.refresh_ranking_cache(session, mode, ranking_type, country)
+                    task = self.refresh_ranking_cache(
+                        session, mode, ranking_type, country
+                    )
                     refresh_tasks.append(task)
-        
+
         # 地区排行榜
         for mode in game_modes:
             task = self.refresh_country_ranking_cache(session, mode)
             refresh_tasks.append(task)
-        
+
         # 并发执行刷新任务,但限制并发数
         semaphore = asyncio.Semaphore(5)  # 最多同时5个任务
-        
+
         async def bounded_refresh(task):
             async with semaphore:
                 await task
-        
+
         bounded_tasks = [bounded_refresh(task) for task in refresh_tasks]
-        
+
         try:
             await asyncio.gather(*bounded_tasks, return_exceptions=True)
             logger.info("All ranking cache refresh completed")
@@ -489,7 +484,7 @@ class RankingCacheService:
         """使缓存失效"""
         try:
             deleted_keys = 0
-            
+
             if ruleset and type:
                 # 删除特定的用户排行榜缓存
                 country_part = f":{country.upper()}" if country else ""
@@ -498,12 +493,14 @@ class RankingCacheService:
                 if keys:
                     await self.redis.delete(*keys)
                     deleted_keys += len(keys)
-                    logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
+                    logger.info(
+                        f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
+                    )
             elif ruleset:
                 # 删除特定游戏模式的所有缓存
                 patterns = [
                     f"ranking:{ruleset}:*",
-                    f"country_ranking:{ruleset}:*" if include_country_ranking else None
+                    f"country_ranking:{ruleset}:*" if include_country_ranking else None,
                 ]
                 for pattern in patterns:
                     if pattern:
@@ -516,15 +513,15 @@ class RankingCacheService:
                 patterns = ["ranking:*"]
                 if include_country_ranking:
                     patterns.append("country_ranking:*")
-                
+
                 for pattern in patterns:
                     keys = await self.redis.keys(pattern)
                     if keys:
                         await self.redis.delete(*keys)
                         deleted_keys += len(keys)
-            
+
             logger.info(f"Invalidated all {deleted_keys} ranking cache keys")
-            
+
         except Exception as e:
             logger.error(f"Error invalidating cache: {e}")
 
@@ -535,7 +532,7 @@ class RankingCacheService:
             pattern = f"country_ranking:{ruleset}:*"
         else:
             pattern = "country_ranking:*"
-        
+
         keys = await self.redis.keys(pattern)
         if keys:
             await self.redis.delete(*keys)
@@ -550,10 +547,10 @@ class RankingCacheService:
             ranking_keys = await self.redis.keys("ranking:*")
             # 获取地区排行榜缓存
             country_keys = await self.redis.keys("country_ranking:*")
-            
+
             total_keys = ranking_keys + country_keys
             total_size = 0
-            
+
             for key in total_keys[:100]:  # 限制检查数量以避免性能问题
                 try:
                     size = await self.redis.memory_usage(key)
@@ -561,7 +558,7 @@ class RankingCacheService:
                         total_size += size
                 except Exception:
                     continue
-            
+
             return {
                 "cached_user_rankings": len(ranking_keys),
                 "cached_country_rankings": len(country_keys),
diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py
index ed0926e..f999fe8 100644
--- a/app/service/user_cache_service.py
+++ b/app/service/user_cache_service.py
@@ -2,24 +2,23 @@
 用户缓存服务
 用于缓存用户信息,提供热缓存和实时刷新功能
 """
+
 from __future__ import annotations
 
-import asyncio
+from datetime import datetime
 import json
-from datetime import UTC, datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any
 
 from app.config import settings
 from app.const import BANCHOBOT_ID
 from app.database import User, UserResp
 from app.database.lazer_user import SEARCH_INCLUDED
-from app.database.pp_best_score import PPBestScore
-from app.database.score import Score, ScoreResp
+from app.database.score import ScoreResp
 from app.log import logger
 from app.models.score import GameMode
 
 from redis.asyncio import Redis
-from sqlmodel import col, exists, select
+from sqlmodel import col, select
 from sqlmodel.ext.asyncio.session import AsyncSession
 
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
 
 class DateTimeEncoder(json.JSONEncoder):
     """自定义 JSON 编码器,支持 datetime 序列化"""
-    
+
     def default(self, obj):
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -48,16 +47,16 @@ class UserCacheService:
         self._refreshing = False
         self._background_tasks: set = set()
 
-    def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
+    def _get_v1_user_cache_key(
+        self, user_id: int, ruleset: GameMode | None = None
+    ) -> str:
         """生成 V1 用户缓存键"""
         if ruleset:
             return f"v1_user:{user_id}:ruleset:{ruleset}"
         return f"v1_user:{user_id}"
 
     async def get_v1_user_from_cache(
-        self,
-        user_id: int,
-        ruleset: GameMode | None = None
+        self, user_id: int, ruleset: GameMode | None = None
     ) -> dict | None:
         """从缓存获取 V1 用户信息"""
         try:
@@ -72,11 +71,11 @@ class UserCacheService:
             return None
 
     async def cache_v1_user(
-        self, 
-        user_data: dict, 
+        self,
+        user_data: dict,
         user_id: int,
         ruleset: GameMode | None = None,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """缓存 V1 用户信息"""
         try:
@@ -97,7 +96,9 @@ class UserCacheService:
             keys = await self.redis.keys(pattern)
             if keys:
                 await self.redis.delete(*keys)
-                logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
+                logger.info(
+                    f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
+                )
         except Exception as e:
             logger.error(f"Error invalidating V1 user cache: {e}")
 
@@ -108,31 +109,25 @@ class UserCacheService:
         return f"user:{user_id}"
 
     def _get_user_scores_cache_key(
-        self, 
-        user_id: int, 
-        score_type: str, 
+        self,
+        user_id: int,
+        score_type: str,
         mode: GameMode | None = None,
         limit: int = 100,
-        offset: int = 0
+        offset: int = 0,
     ) -> str:
         """生成用户成绩缓存键"""
         mode_part = f":{mode}" if mode else ""
         return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
 
     def _get_user_beatmapsets_cache_key(
-        self, 
-        user_id: int, 
-        beatmapset_type: str, 
-        limit: int = 100, 
-        offset: int = 0
+        self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
     ) -> str:
         """生成用户谱面集缓存键"""
         return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
 
     async def get_user_from_cache(
-        self, 
-        user_id: int, 
-        ruleset: GameMode | None = None
+        self, user_id: int, ruleset: GameMode | None = None
     ) -> UserResp | None:
         """从缓存获取用户信息"""
         try:
@@ -148,10 +143,10 @@ class UserCacheService:
             return None
 
     async def cache_user(
-        self, 
-        user_resp: UserResp, 
+        self,
+        user_resp: UserResp,
         ruleset: GameMode | None = None,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """缓存用户信息"""
         try:
@@ -173,14 +168,18 @@ class UserCacheService:
         score_type: str,
         mode: GameMode | None = None,
         limit: int = 100,
-        offset: int = 0
+        offset: int = 0,
     ) -> list[ScoreResp] | None:
         """从缓存获取用户成绩"""
         try:
-            cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+            cache_key = self._get_user_scores_cache_key(
+                user_id, score_type, mode, limit, offset
+            )
             cached_data = await self.redis.get(cache_key)
             if cached_data:
-                logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
+                logger.debug(
+                    f"User scores cache hit for user {user_id}, type {score_type}"
+                )
                 data = json.loads(cached_data)
                 return [ScoreResp(**score_data) for score_data in data]
             return None
@@ -196,34 +195,38 @@ class UserCacheService:
         mode: GameMode | None = None,
         limit: int = 100,
         offset: int = 0,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """缓存用户成绩"""
         try:
             if expire_seconds is None:
                 expire_seconds = settings.user_scores_cache_expire_seconds
-            cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+            cache_key = self._get_user_scores_cache_key(
+                user_id, score_type, mode, limit, offset
+            )
             # 使用 model_dump_json() 而不是 model_dump() + json.dumps()
             scores_json_list = [score.model_dump_json() for score in scores]
             cached_data = f"[{','.join(scores_json_list)}]"
             await self.redis.setex(cache_key, expire_seconds, cached_data)
-            logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
+            logger.debug(
+                f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
+            )
         except Exception as e:
             logger.error(f"Error caching user scores: {e}")
 
     async def get_user_beatmapsets_from_cache(
-        self, 
-        user_id: int, 
-        beatmapset_type: str, 
-        limit: int = 100, 
-        offset: int = 0
+        self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
     ) -> list[Any] | None:
         """从缓存获取用户谱面集"""
         try:
-            cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+            cache_key = self._get_user_beatmapsets_cache_key(
+                user_id, beatmapset_type, limit, offset
+            )
             cached_data = await self.redis.get(cache_key)
             if cached_data:
-                logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
+                logger.debug(
+                    f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
+                )
                 return json.loads(cached_data)
             return None
         except Exception as e:
@@ -237,23 +240,27 @@ class UserCacheService:
         beatmapsets: list[Any],
         limit: int = 100,
         offset: int = 0,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """缓存用户谱面集"""
         try:
             if expire_seconds is None:
                 expire_seconds = settings.user_beatmapsets_cache_expire_seconds
-            cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+            cache_key = self._get_user_beatmapsets_cache_key(
+                user_id, beatmapset_type, limit, offset
+            )
             # 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
             serialized_beatmapsets = []
             for bms in beatmapsets:
-                if hasattr(bms, 'model_dump_json'):
+                if hasattr(bms, "model_dump_json"):
                     serialized_beatmapsets.append(bms.model_dump_json())
                 else:
                     serialized_beatmapsets.append(safe_json_dumps(bms))
             cached_data = f"[{','.join(serialized_beatmapsets)}]"
             await self.redis.setex(cache_key, expire_seconds, cached_data)
-            logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
+            logger.debug(
+                f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
+            )
         except Exception as e:
             logger.error(f"Error caching user beatmapsets: {e}")
 
@@ -269,7 +276,9 @@ class UserCacheService:
         except Exception as e:
             logger.error(f"Error invalidating user cache: {e}")
 
-    async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
+    async def invalidate_user_scores_cache(
+        self, user_id: int, mode: GameMode | None = None
+    ):
         """使用户成绩缓存失效"""
         try:
             # 删除用户成绩相关缓存
@@ -278,7 +287,9 @@ class UserCacheService:
             keys = await self.redis.keys(pattern)
             if keys:
                 await self.redis.delete(*keys)
-                logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
+                logger.info(
+                    f"Invalidated {len(keys)} score cache entries for user {user_id}"
+                )
         except Exception as e:
             logger.error(f"Error invalidating user scores cache: {e}")
 
@@ -290,12 +301,10 @@ class UserCacheService:
         self._refreshing = True
         try:
             logger.info(f"Preloading cache for {len(user_ids)} users")
-            
+
             # 批量获取用户
             users = (
-                await session.exec(
-                    select(User).where(col(User.id).in_(user_ids))
-                )
+                await session.exec(select(User).where(col(User.id).in_(user_ids)))
             ).all()
 
             # 串行缓存用户信息,避免并发数据库访问问题
@@ -324,10 +333,7 @@ class UserCacheService:
             logger.error(f"Error caching single user {user.id}: {e}")
 
     async def refresh_user_cache_on_score_submit(
-        self, 
-        session: AsyncSession, 
-        user_id: int, 
-        mode: GameMode
+        self, session: AsyncSession, user_id: int, mode: GameMode
     ):
         """成绩提交后刷新用户缓存"""
         try:
@@ -335,7 +341,7 @@ class UserCacheService:
             await self.invalidate_user_cache(user_id)
             await self.invalidate_v1_user_cache(user_id)
             await self.invalidate_user_scores_cache(user_id, mode)
-            
+
             # 立即重新加载用户信息
             user = await session.get(User, user_id)
             if user and user.id != BANCHOBOT_ID:
@@ -351,7 +357,7 @@ class UserCacheService:
             v1_user_keys = await self.redis.keys("v1_user:*")
             all_keys = user_keys + v1_user_keys
             total_size = 0
-            
+
             for key in all_keys[:100]:  # 限制检查数量
                 try:
                     size = await self.redis.memory_usage(key)
@@ -359,12 +365,22 @@ class UserCacheService:
                         total_size += size
                 except Exception:
                     continue
-            
+
             return {
-                "cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
-                "cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
+                "cached_users": len(
+                    [
+                        k
+                        for k in user_keys
+                        if ":scores:" not in k and ":beatmapsets:" not in k
+                    ]
+                ),
+                "cached_v1_users": len(
+                    [k for k in v1_user_keys if ":scores:" not in k]
+                ),
                 "cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
-                "cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
+                "cached_user_beatmapsets": len(
+                    [k for k in user_keys if ":beatmapsets:" in k]
+                ),
                 "total_cached_entries": len(all_keys),
                 "estimated_total_size_mb": (
                     round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
diff --git a/test_ranking_serialization.py b/test_ranking_serialization.py
index 91e767a..e9b81fd 100644
--- a/test_ranking_serialization.py
+++ b/test_ranking_serialization.py
@@ -1,92 +1,98 @@
 #!/usr/bin/env python3
 """测试排行榜缓存序列化修复"""
 
-import asyncio
+from __future__ import annotations
+
+from datetime import UTC, datetime
 import warnings
-from datetime import datetime, UTC
-from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps
+
+from app.service.ranking_cache_service import safe_json_dumps
 
 
 def test_datetime_serialization():
     """测试 datetime 序列化"""
     print("🧪 测试 datetime 序列化...")
-    
+
     test_data = {
         "id": 1,
         "username": "test_user",
         "last_updated": datetime.now(UTC),
         "join_date": datetime(2020, 1, 1, tzinfo=UTC),
-        "stats": {
-            "pp": 1000.0,
-            "accuracy": 95.5,
-            "last_played": datetime.now(UTC)
-        }
+        "stats": {"pp": 1000.0, "accuracy": 95.5, "last_played": datetime.now(UTC)},
     }
-    
+
     try:
         # 测试自定义编码器
         json_result = safe_json_dumps(test_data)
         print("✅ datetime 序列化成功")
         print(f"   序列化结果长度: {len(json_result)}")
-        
+
         # 验证可以重新解析
         import json
+
         parsed = json.loads(json_result)
         assert "last_updated" in parsed
         assert isinstance(parsed["last_updated"], str)
         print("✅ 序列化的 JSON 可以正确解析")
-        
+
     except Exception as e:
         print(f"❌ datetime 序列化测试失败: {e}")
         import traceback
+
         traceback.print_exc()
 
 
 def test_boolean_serialization():
     """测试布尔值序列化"""
     print("\n🧪 测试布尔值序列化...")
-    
+
     test_data = {
         "user": {
-            "is_active": 1,      # 数据库中的整数布尔值
-            "is_supporter": 0,   # 数据库中的整数布尔值
-            "has_profile": True, # 正常布尔值
+            "is_active": 1,  # 数据库中的整数布尔值
+            "is_supporter": 0,  # 数据库中的整数布尔值
+            "has_profile": True,  # 正常布尔值
         },
         "stats": {
-            "is_ranked": 1,      # 数据库中的整数布尔值
-            "verified": False,   # 正常布尔值
-        }
+            "is_ranked": 1,  # 数据库中的整数布尔值
+            "verified": False,  # 正常布尔值
+        },
     }
-    
+
     try:
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
             json_result = safe_json_dumps(test_data)
-            
+
             # 检查是否有 Pydantic 序列化警告
-            pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
+            pydantic_warnings = [
+                warning
+                for warning in w
+                if "PydanticSerializationUnexpectedValue" in str(warning.message)
+            ]
             if pydantic_warnings:
                 print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告")
                 for warning in pydantic_warnings:
                     print(f"   {warning.message}")
             else:
                 print("✅ 布尔值序列化无警告")
-            
+
             # 验证序列化结果
             import json
+
             parsed = json.loads(json_result)
             print(f"✅ 布尔值序列化成功,结果: {parsed}")
-            
+
     except Exception as e:
         print(f"❌ 布尔值序列化测试失败: {e}")
         import traceback
+
         traceback.print_exc()
 
 
 def test_complex_ranking_data():
     """测试复杂的排行榜数据序列化"""
     print("\n🧪 测试复杂排行榜数据序列化...")
-    
+
     # 模拟排行榜数据结构
     ranking_data = [
         {
@@ -95,8 +101,8 @@ def test_complex_ranking_data():
                 "id": 1,
                 "username": "player1",
                 "country_code": "US",
-                "is_active": 1, # 整数布尔值
-                "is_supporter": 0, # 整数布尔值
+                "is_active": 1,  # 整数布尔值
+                "is_supporter": 0,  # 整数布尔值
                 "join_date": datetime(2020, 1, 1, tzinfo=UTC),
                 "last_visit": datetime.now(UTC),
             },
@@ -104,9 +110,9 @@ def test_complex_ranking_data():
                 "pp": 8000.0,
                 "accuracy": 98.5,
                 "play_count": 5000,
-                "is_ranked": 1, # 整数布尔值
+                "is_ranked": 1,  # 整数布尔值
                 "last_updated": datetime.now(UTC),
-            }
+            },
         },
         {
             "id": 2,
@@ -125,41 +131,47 @@ def test_complex_ranking_data():
                 "play_count": 4500,
                 "is_ranked": 1,
                 "last_updated": datetime.now(UTC),
-            }
-        }
+            },
+        },
     ]
-    
+
     try:
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
             json_result = safe_json_dumps(ranking_data)
-            
-            pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
+
+            pydantic_warnings = [
+                warning
+                for warning in w
+                if "PydanticSerializationUnexpectedValue" in str(warning.message)
+            ]
             if pydantic_warnings:
                 print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
                 for warning in pydantic_warnings:
                     print(f"   {warning.message}")
             else:
                 print("✅ 复杂排行榜数据序列化无警告")
-            
+
             # 验证序列化结果
             import json
+
             parsed = json.loads(json_result)
             assert len(parsed) == 2
             assert parsed[0]["user"]["username"] == "player1"
             print(f"✅ 复杂排行榜数据序列化成功,包含 {len(parsed)} 个条目")
-            
+
     except Exception as e:
         print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
         import traceback
+
         traceback.print_exc()
 
 
 if __name__ == "__main__":
     print("🚀 开始排行榜缓存序列化测试\n")
-    
+
     test_datetime_serialization()
-    test_boolean_serialization() 
+    test_boolean_serialization()
     test_complex_ranking_data()
-    
+
     print("\n🎉 排行榜缓存序列化测试完成!")