From b316511cf542e3b3ea5f52d3681725a7084e148f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com> Date: Thu, 21 Aug 2025 21:35:08 +0800 Subject: [PATCH] add ranking cache --- app/config.py | 7 + app/router/v2/ranking.py | 221 +++++++++- app/scheduler/cache_scheduler.py | 83 +++- app/service/ranking_cache_service.py | 586 +++++++++++++++++++++++++++ 4 files changed, 879 insertions(+), 18 deletions(-) create mode 100644 app/service/ranking_cache_service.py diff --git a/app/config.py b/app/config.py index c12f405..1c679f5 100644 --- a/app/config.py +++ b/app/config.py @@ -142,6 +142,13 @@ class Settings(BaseSettings): beatmap_cache_expire_hours: int = 24 max_concurrent_pp_calculations: int = 10 enable_pp_calculation_threading: bool = True + + # 排行榜缓存设置 + enable_ranking_cache: bool = True + ranking_cache_expire_minutes: int = 10 # 排行榜缓存过期时间(分钟) + ranking_cache_refresh_interval_minutes: int = 10 # 排行榜缓存刷新间隔(分钟) + ranking_cache_max_pages: int = 20 # 最多缓存的页数 + ranking_cache_top_countries: int = 20 # 缓存前N个国家的排行榜 # 反作弊设置 suspicious_score_check: bool = True diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py index 711a793..3bc640f 100644 --- a/app/router/v2/ranking.py +++ b/app/router/v2/ranking.py @@ -2,11 +2,13 @@ from __future__ import annotations from typing import Literal +from app.config import settings from app.database import User from app.database.statistics import UserStatistics, UserStatisticsResp from app.dependencies import get_current_user -from app.dependencies.database import Database +from app.dependencies.database import Database, get_redis from app.models.score import GameMode +from app.service.ranking_cache_service import get_ranking_cache_service from .router import router @@ -40,9 +42,27 @@ async def get_country_ranking( page: int = Query(1, ge=1, description="页码"), # TODO current_user: User = Security(get_current_user, scopes=["public"]), ): + # 获取 Redis 连接和缓存服务 + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + # 尝试从缓存获取数据 + cached_data = await cache_service.get_cached_country_ranking(ruleset, page) + + if cached_data: + # 从缓存返回数据 + return CountryResponse( + ranking=[CountryStatistics.model_validate(item) for item in cached_data] + ) + + # 缓存未命中,从数据库查询 response = CountryResponse(ranking=[]) countries = (await session.exec(select(User.country_code).distinct())).all() + for country in countries: + if not country: # 跳过空的国家代码 + continue + statistics = ( await session.exec( select(UserStatistics).where( @@ -53,6 +73,10 @@ async def get_country_ranking( ) ) ).all() + + if not statistics: # 跳过没有数据的国家 + continue + pp = 0 country_stats = CountryStatistics( code=country, @@ -68,7 +92,29 @@ async def get_country_ranking( pp += stat.pp country_stats.performance = round(pp) response.ranking.append(country_stats) + response.ranking.sort(key=lambda x: x.performance, reverse=True) + + # 分页处理 + page_size = 50 + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + + # 获取当前页的数据 + current_page_data = response.ranking[start_idx:end_idx] + + # 异步缓存数据(不等待完成) + cache_data = [item.model_dump() for item in current_page_data] + cache_task = cache_service.cache_country_ranking( + ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60 + ) + + # 创建后台任务来缓存数据 + import asyncio + asyncio.create_task(cache_task) + + # 返回当前页的结果 + response.ranking = current_page_data return response @@ -93,6 +139,24 @@ async def get_user_ranking( page: int = Query(1, ge=1, description="页码"), current_user: User = Security(get_current_user, scopes=["public"]), ): + # 获取 Redis 连接和缓存服务 + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + # 尝试从缓存获取数据 + cached_data = await cache_service.get_cached_ranking( + ruleset, type, country, page + ) + + if cached_data: + # 从缓存返回数据 + return TopUsersResponse( + ranking=[ + UserStatisticsResp.model_validate(item) for item in cached_data + ] + ) + + # 缓存未命中,从数据库查询 wheres = [ col(UserStatistics.mode) == ruleset, col(UserStatistics.pp) > 0, @@ -106,6 +170,7 @@ async def get_user_ranking( order_by = col(UserStatistics.ranked_score).desc() if country: wheres.append(col(UserStatistics.user).has(country_code=country.upper())) + statistics_list = await session.exec( select(UserStatistics) .where(*wheres) @@ -113,10 +178,154 @@ async def get_user_ranking( .limit(50) .offset(50 * (page - 1)) ) - resp = TopUsersResponse( - ranking=[ - await UserStatisticsResp.from_db(statistics, session, None, include) - for statistics in statistics_list - ] + + # 转换为响应格式 + ranking_data = [] + for statistics in statistics_list: + user_stats_resp = await UserStatisticsResp.from_db( + statistics, session, None, include + ) + ranking_data.append(user_stats_resp) + + # 异步缓存数据(不等待完成) + # 使用配置文件中的TTL设置 + cache_data = [item.model_dump() for item in ranking_data] + cache_task = cache_service.cache_ranking( + ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60 ) + + # 创建后台任务来缓存数据 + import asyncio + asyncio.create_task(cache_task) + + resp = TopUsersResponse(ranking=ranking_data) return resp + + +""" @router.post( + "/rankings/cache/refresh", + name="刷新排行榜缓存", + description="手动刷新排行榜缓存(管理员功能)", + tags=["排行榜", "管理"], +) +async def refresh_ranking_cache( + session: Database, + ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"), + type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"), + country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"), + include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), + current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +): + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + if ruleset and type: + # 刷新特定的用户排行榜 + await cache_service.refresh_ranking_cache(session, ruleset, type, country) + message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") + + # 如果请求刷新地区排行榜 + if include_country_ranking and not country: # 地区排行榜不依赖于国家参数 + await cache_service.refresh_country_ranking_cache(session, ruleset) + message += f" and country ranking for {ruleset}" + + return {"message": message} + elif ruleset: + # 刷新特定游戏模式的所有排行榜 + ranking_types: list[Literal["performance", "score"]] = ["performance", "score"] + for ranking_type in ranking_types: + await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country) + + if include_country_ranking: + await cache_service.refresh_country_ranking_cache(session, ruleset) + + return {"message": f"Refreshed all ranking caches for {ruleset}"} + else: + # 刷新所有排行榜 + await cache_service.refresh_all_rankings(session) + return {"message": "Refreshed all ranking caches"} + + +@router.post( + "/rankings/{ruleset}/country/cache/refresh", + name="刷新地区排行榜缓存", + description="手动刷新地区排行榜缓存(管理员功能)", + tags=["排行榜", "管理"], +) +async def refresh_country_ranking_cache( + session: Database, + ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"), + current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +): + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + await cache_service.refresh_country_ranking_cache(session, ruleset) + return {"message": f"Refreshed country ranking cache for {ruleset}"} + + +@router.delete( + "/rankings/cache", + name="清除排行榜缓存", + description="清除排行榜缓存(管理员功能)", + tags=["排行榜", "管理"], +) +async def clear_ranking_cache( + ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), + type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"), + country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"), + include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), + current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +): + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking) + + if ruleset and type: + message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") + if include_country_ranking: + message += " and country ranking" + return {"message": message} + else: + message = "Cleared all ranking caches" + if include_country_ranking: + message += " including country rankings" + return {"message": message} + + +@router.delete( + "/rankings/{ruleset}/country/cache", + name="清除地区排行榜缓存", + description="清除地区排行榜缓存(管理员功能)", + tags=["排行榜", "管理"], +) +async def clear_country_ranking_cache( + ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), + current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +): + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + await cache_service.invalidate_country_cache(ruleset) + + if ruleset: + return {"message": f"Cleared country ranking cache for {ruleset}"} + else: + return {"message": "Cleared all country ranking caches"} + + +@router.get( + "/rankings/cache/stats", + name="获取排行榜缓存统计", + description="获取排行榜缓存统计信息(管理员功能)", + tags=["排行榜", "管理"], +) +async def get_ranking_cache_stats( + current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +): + redis = get_redis() + cache_service = get_ranking_cache_service(redis) + + stats = await cache_service.get_cache_stats() + return stats """ \ No newline at end of file diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py index 254d315..4c4e8ae 100644 --- a/app/scheduler/cache_scheduler.py +++ b/app/scheduler/cache_scheduler.py @@ -2,13 +2,14 @@ from __future__ import annotations import asyncio -from app.dependencies.database import get_redis +from app.config import settings +from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.log import logger -class BeatmapsetCacheScheduler: - """谱面集缓存调度器""" +class CacheScheduler: + """缓存调度器 - 统一管理各种缓存任务""" def __init__(self): self.running = False @@ -21,7 +22,7 @@ class BeatmapsetCacheScheduler: self.running = True self.task = asyncio.create_task(self._run_scheduler()) - logger.info("BeatmapsetCacheScheduler started") + logger.info("CacheScheduler started") async def stop(self): """停止调度器""" @@ -32,20 +33,47 @@ class BeatmapsetCacheScheduler: await self.task except asyncio.CancelledError: pass - logger.info("BeatmapsetCacheScheduler stopped") + logger.info("CacheScheduler stopped") async def _run_scheduler(self): """运行调度器主循环""" # 启动时立即执行一次预热 await self._warmup_cache() + + # 启动时执行一次排行榜缓存刷新 + await self._refresh_ranking_cache() + + beatmap_cache_counter = 0 + ranking_cache_counter = 0 + + # 从配置文件获取间隔设置 + check_interval = 5 * 60 # 5分钟检查间隔 + beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔 + ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取 + + beatmap_cache_cycles = beatmap_cache_interval // check_interval + ranking_cache_cycles = ranking_cache_interval // check_interval while self.running: try: - # 每30分钟执行一次缓存预热 - await asyncio.sleep(30 * 60) # 30分钟 + # 每5分钟检查一次 + await asyncio.sleep(check_interval) - if self.running: + if not self.running: + break + + beatmap_cache_counter += 1 + ranking_cache_counter += 1 + + # beatmap缓存预热 + if beatmap_cache_counter >= beatmap_cache_cycles: await self._warmup_cache() + beatmap_cache_counter = 0 + + # 排行榜缓存刷新 + if ranking_cache_counter >= ranking_cache_cycles: + await self._refresh_ranking_cache() + ranking_cache_counter = 0 except asyncio.CancelledError: break @@ -56,7 +84,7 @@ class BeatmapsetCacheScheduler: async def _warmup_cache(self): """执行缓存预热""" try: - logger.info("Starting cache warmup...") + logger.info("Starting beatmap cache warmup...") fetcher = await get_fetcher() redis = get_redis() @@ -64,14 +92,45 @@ class BeatmapsetCacheScheduler: # 预热主页缓存 await fetcher.warmup_homepage_cache(redis) - logger.info("Cache warmup completed successfully") + logger.info("Beatmap cache warmup completed successfully") except Exception as e: - logger.error(f"Cache warmup failed: {e}") + logger.error(f"Beatmap cache warmup failed: {e}") + + async def _refresh_ranking_cache(self): + """刷新排行榜缓存""" + try: + logger.info("Starting ranking cache refresh...") + + redis = get_redis() + + # 导入排行榜缓存服务 + from app.service.ranking_cache_service import ( + get_ranking_cache_service, + schedule_ranking_refresh_task, + ) + + # 获取数据库会话 + async for session in get_db(): + await schedule_ranking_refresh_task(session, redis) + break # 只需要一次会话 + + logger.info("Ranking cache refresh completed successfully") + + except Exception as e: + logger.error(f"Ranking cache refresh failed: {e}") + + +# Beatmap缓存调度器(保持向后兼容) +class BeatmapsetCacheScheduler(CacheScheduler): + """谱面集缓存调度器 - 为了向后兼容""" + pass # 全局调度器实例 -cache_scheduler = BeatmapsetCacheScheduler() +cache_scheduler = CacheScheduler() +# 保持向后兼容的别名 +beatmapset_cache_scheduler = BeatmapsetCacheScheduler() async def start_cache_scheduler(): diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py new file mode 100644 index 0000000..a2b9894 --- /dev/null +++ b/app/service/ranking_cache_service.py @@ -0,0 +1,586 @@ +""" +用户排行榜缓存服务 +用于缓存用户排行榜数据,减轻数据库压力 +""" +from __future__ import annotations + +import asyncio +import json +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Literal + +from app.config import settings +from app.database.statistics import UserStatistics, UserStatisticsResp +from app.log import logger +from app.models.score import GameMode + +from redis.asyncio import Redis +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + pass + + +class RankingCacheService: + """用户排行榜缓存服务""" + + def __init__(self, redis: Redis): + self.redis = redis + self._refreshing = False + self._background_tasks: set = set() + + def _get_cache_key( + self, + ruleset: GameMode, + type: Literal["performance", "score"], + country: str | None = None, + page: int = 1, + ) -> str: + """生成缓存键""" + country_part = f":{country.upper()}" if country else "" + return f"ranking:{ruleset}:{type}{country_part}:page:{page}" + + def _get_stats_cache_key( + self, + ruleset: GameMode, + type: Literal["performance", "score"], + country: str | None = None, + ) -> str: + """生成统计信息缓存键""" + country_part = f":{country.upper()}" if country else "" + return f"ranking:stats:{ruleset}:{type}{country_part}" + + def _get_country_cache_key(self, ruleset: GameMode, page: int = 1) -> str: + """生成地区排行榜缓存键""" + return f"country_ranking:{ruleset}:page:{page}" + + def _get_country_stats_cache_key(self, ruleset: GameMode) -> str: + """生成地区排行榜统计信息缓存键""" + return f"country_ranking:stats:{ruleset}" + + async def get_cached_ranking( + self, + ruleset: GameMode, + type: Literal["performance", "score"], + country: str | None = None, + page: int = 1, + ) -> list[dict] | None: + """获取缓存的排行榜数据""" + try: + cache_key = self._get_cache_key(ruleset, type, country, page) + cached_data = await self.redis.get(cache_key) + + if cached_data: + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error getting cached ranking: {e}") + return None + + async def cache_ranking( + self, + ruleset: GameMode, + type: Literal["performance", "score"], + ranking_data: list[dict], + country: str | None = None, + page: int = 1, + ttl: int | None = None, # 允许为None以使用配置文件的默认值 + ) -> None: + """缓存排行榜数据""" + try: + cache_key = self._get_cache_key(ruleset, type, country, page) + # 使用配置文件的TTL设置 + if ttl is None: + ttl = settings.ranking_cache_expire_minutes * 60 + await self.redis.set( + cache_key, + json.dumps(ranking_data, separators=(",", ":")), + ex=ttl + ) + logger.debug(f"Cached ranking data for {cache_key}") + except Exception as e: + logger.error(f"Error caching ranking: {e}") + + async def get_cached_stats( + self, + ruleset: GameMode, + type: Literal["performance", "score"], + country: str | None = None, + ) -> dict | None: + """获取缓存的统计信息""" + try: + cache_key = self._get_stats_cache_key(ruleset, type, country) + cached_data = await self.redis.get(cache_key) + + if cached_data: + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error getting cached stats: {e}") + return None + + async def cache_stats( + self, + ruleset: GameMode, + type: Literal["performance", "score"], + stats: dict, + country: str | None = None, + ttl: int | None = None, # 允许为None以使用配置文件的默认值 + ) -> None: + """缓存统计信息""" + try: + cache_key = self._get_stats_cache_key(ruleset, type, country) + # 使用配置文件的TTL设置,统计信息缓存时间更长 + if ttl is None: + ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间 + await self.redis.set( + cache_key, + json.dumps(stats, separators=(",", ":")), + ex=ttl + ) + logger.debug(f"Cached stats for {cache_key}") + except Exception as e: + logger.error(f"Error caching stats: {e}") + + async def get_cached_country_ranking( + self, + ruleset: GameMode, + page: int = 1, + ) -> list[dict] | None: + """获取缓存的地区排行榜数据""" + try: + cache_key = self._get_country_cache_key(ruleset, page) + cached_data = await self.redis.get(cache_key) + + if cached_data: + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error getting cached country ranking: {e}") + return None + + async def cache_country_ranking( + self, + ruleset: GameMode, + ranking_data: list[dict], + page: int = 1, + ttl: int | None = None, + ) -> None: + """缓存地区排行榜数据""" + try: + cache_key = self._get_country_cache_key(ruleset, page) + if ttl is None: + ttl = settings.ranking_cache_expire_minutes * 60 + await self.redis.set( + cache_key, + json.dumps(ranking_data, separators=(",", ":")), + ex=ttl + ) + logger.debug(f"Cached country ranking data for {cache_key}") + except Exception as e: + logger.error(f"Error caching country ranking: {e}") + + async def get_cached_country_stats(self, ruleset: GameMode) -> dict | None: + """获取缓存的地区排行榜统计信息""" + try: + cache_key = self._get_country_stats_cache_key(ruleset) + cached_data = await self.redis.get(cache_key) + + if cached_data: + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error getting cached country stats: {e}") + return None + + async def cache_country_stats( + self, + ruleset: GameMode, + stats: dict, + ttl: int | None = None, + ) -> None: + """缓存地区排行榜统计信息""" + try: + cache_key = self._get_country_stats_cache_key(ruleset) + if ttl is None: + ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间 + await self.redis.set( + cache_key, + json.dumps(stats, separators=(",", ":")), + ex=ttl + ) + logger.debug(f"Cached country stats for {cache_key}") + except Exception as e: + logger.error(f"Error caching country stats: {e}") + + async def refresh_ranking_cache( + self, + session: AsyncSession, + ruleset: GameMode, + type: Literal["performance", "score"], + country: str | None = None, + max_pages: int | None = None, # 允许为None以使用配置文件的默认值 + ) -> None: + """刷新排行榜缓存""" + if self._refreshing: + logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}") + return + + # 使用配置文件的设置 + if max_pages is None: + max_pages = settings.ranking_cache_max_pages + + self._refreshing = True + try: + logger.info(f"Starting ranking cache refresh for {ruleset}:{type}") + + # 构建查询条件 + wheres = [ + col(UserStatistics.mode) == ruleset, + col(UserStatistics.pp) > 0, + col(UserStatistics.is_ranked).is_(True), + ] + include = ["user"] + + if type == "performance": + order_by = col(UserStatistics.pp).desc() + include.append("rank_change_since_30_days") + else: + order_by = col(UserStatistics.ranked_score).desc() + + if country: + wheres.append(col(UserStatistics.user).has(country_code=country.upper())) + + # 获取总用户数用于统计 + total_users_query = select(UserStatistics).where(*wheres) + total_users = len((await session.exec(total_users_query)).all()) + + # 计算统计信息 + stats = { + "total_users": total_users, + "last_updated": datetime.now(UTC).isoformat(), + "type": type, + "ruleset": ruleset, + "country": country, + } + + # 缓存统计信息 + await self.cache_stats(ruleset, type, stats, country) + + # 分页缓存数据 + for page in range(1, max_pages + 1): + try: + statistics_list = await session.exec( + select(UserStatistics) + .where(*wheres) + .order_by(order_by) + .limit(50) + .offset(50 * (page - 1)) + ) + + statistics_data = statistics_list.all() + if not statistics_data: + break # 没有更多数据 + + # 转换为响应格式 + ranking_data = [] + for statistics in statistics_data: + user_stats_resp = await UserStatisticsResp.from_db( + statistics, session, None, include + ) + ranking_data.append(user_stats_resp.model_dump()) + + # 缓存这一页的数据 + await self.cache_ranking( + ruleset, type, ranking_data, country, page + ) + + # 添加延迟避免数据库过载 + if page < max_pages: + await asyncio.sleep(0.1) + + except Exception as e: + logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}") + + logger.info(f"Completed ranking cache refresh for {ruleset}:{type}") + + except Exception as e: + logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}") + finally: + self._refreshing = False + + async def refresh_country_ranking_cache( + self, + session: AsyncSession, + ruleset: GameMode, + max_pages: int | None = None, + ) -> None: + """刷新地区排行榜缓存""" + if self._refreshing: + logger.info(f"Country ranking cache refresh already in progress for {ruleset}") + return + + if max_pages is None: + max_pages = settings.ranking_cache_max_pages + + self._refreshing = True + try: + logger.info(f"Starting country ranking cache refresh for {ruleset}") + + # 获取所有国家 + from app.database import User + countries = (await session.exec(select(User.country_code).distinct())).all() + + # 计算每个国家的统计数据 + country_stats_list = [] + for country in countries: + if not country: # 跳过空的国家代码 + continue + + statistics = ( + await session.exec( + select(UserStatistics).where( + UserStatistics.mode == ruleset, + UserStatistics.pp > 0, + col(UserStatistics.user).has(country_code=country), + col(UserStatistics.user).has(is_active=True), + ) + ) + ).all() + + if not statistics: # 跳过没有数据的国家 + continue + + pp = 0 + country_stats = { + "code": country, + "active_users": 0, + "play_count": 0, + "ranked_score": 0, + "performance": 0, + } + + for stat in statistics: + country_stats["active_users"] += 1 + country_stats["play_count"] += stat.play_count + country_stats["ranked_score"] += stat.ranked_score + pp += stat.pp + + country_stats["performance"] = round(pp) + country_stats_list.append(country_stats) + + # 按表现分排序 + country_stats_list.sort(key=lambda x: x["performance"], reverse=True) + + # 计算统计信息 + stats = { + "total_countries": len(country_stats_list), + "last_updated": datetime.now(UTC).isoformat(), + "ruleset": ruleset, + } + + # 缓存统计信息 + await self.cache_country_stats(ruleset, stats) + + # 分页缓存数据(每页50个国家) + page_size = 50 + for page in range(1, max_pages + 1): + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + + page_data = country_stats_list[start_idx:end_idx] + if not page_data: + break # 没有更多数据 + + # 缓存这一页的数据 + await self.cache_country_ranking(ruleset, page_data, page) + + # 添加延迟避免Redis过载 + if page < max_pages and page_data: + await asyncio.sleep(0.1) + + logger.info(f"Completed country ranking cache refresh for {ruleset}") + + except Exception as e: + logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}") + finally: + self._refreshing = False + + async def refresh_all_rankings(self, session: AsyncSession) -> None: + """刷新所有排行榜缓存""" + game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA] + ranking_types: list[Literal["performance", "score"]] = ["performance", "score"] + + # 获取需要缓存的国家列表(活跃用户数量前20的国家) + from app.database import User + from sqlmodel import func + + countries_query = ( + await session.exec( + select(User.country_code, func.count().label("user_count")) + .where(col(User.is_active).is_(True)) + .group_by(User.country_code) + .order_by(col("user_count").desc()) + .limit(settings.ranking_cache_top_countries) + ) + ).all() + + top_countries = [country for country, _ in countries_query] + + refresh_tasks = [] + + # 全球排行榜 + for mode in game_modes: + for ranking_type in ranking_types: + task = self.refresh_ranking_cache(session, mode, ranking_type) + refresh_tasks.append(task) + + # 国家排行榜(仅前20个国家) + for country in top_countries: + for mode in game_modes: + for ranking_type in ranking_types: + task = self.refresh_ranking_cache(session, mode, ranking_type, country) + refresh_tasks.append(task) + + # 地区排行榜 + for mode in game_modes: + task = self.refresh_country_ranking_cache(session, mode) + refresh_tasks.append(task) + + # 并发执行刷新任务,但限制并发数 + semaphore = asyncio.Semaphore(5) # 最多同时5个任务 + + async def bounded_refresh(task): + async with semaphore: + await task + + bounded_tasks = [bounded_refresh(task) for task in refresh_tasks] + + try: + await asyncio.gather(*bounded_tasks, return_exceptions=True) + logger.info("All ranking cache refresh completed") + except Exception as e: + logger.error(f"Error in batch ranking cache refresh: {e}") + + async def invalidate_cache( + self, + ruleset: GameMode | None = None, + type: Literal["performance", "score"] | None = None, + country: str | None = None, + include_country_ranking: bool = True, + ) -> None: + """使缓存失效""" + try: + deleted_keys = 0 + + if ruleset and type: + # 删除特定的用户排行榜缓存 + country_part = f":{country.upper()}" if country else "" + pattern = f"ranking:{ruleset}:{type}{country_part}:page:*" + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + deleted_keys += len(keys) + logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}") + elif ruleset: + # 删除特定游戏模式的所有缓存 + patterns = [ + f"ranking:{ruleset}:*", + f"country_ranking:{ruleset}:*" if include_country_ranking else None + ] + for pattern in patterns: + if pattern: + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + deleted_keys += len(keys) + else: + # 删除所有排行榜缓存 + patterns = ["ranking:*"] + if include_country_ranking: + patterns.append("country_ranking:*") + + for pattern in patterns: + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + deleted_keys += len(keys) + + logger.info(f"Invalidated all {deleted_keys} ranking cache keys") + + except Exception as e: + logger.error(f"Error invalidating cache: {e}") + + async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None: + """使地区排行榜缓存失效""" + try: + if ruleset: + pattern = f"country_ranking:{ruleset}:*" + else: + pattern = "country_ranking:*" + + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + logger.info(f"Invalidated {len(keys)} country ranking cache keys") + except Exception as e: + logger.error(f"Error invalidating country cache: {e}") + + async def get_cache_stats(self) -> dict: + """获取缓存统计信息""" + try: + # 获取用户排行榜缓存 + ranking_keys = await self.redis.keys("ranking:*") + # 获取地区排行榜缓存 + country_keys = await self.redis.keys("country_ranking:*") + + total_keys = ranking_keys + country_keys + total_size = 0 + + for key in total_keys[:100]: # 限制检查数量以避免性能问题 + try: + size = await self.redis.memory_usage(key) + if size: + total_size += size + except Exception: + continue + + return { + "cached_user_rankings": len(ranking_keys), + "cached_country_rankings": len(country_keys), + "total_cached_rankings": len(total_keys), + "estimated_total_size_mb": ( + round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 + ), + "refreshing": self._refreshing, + } + except Exception as e: + logger.error(f"Error getting cache stats: {e}") + return {"error": str(e)} + + +# 全局缓存服务实例 +_ranking_cache_service: RankingCacheService | None = None + + +def get_ranking_cache_service(redis: Redis) -> RankingCacheService: + """获取排行榜缓存服务实例""" + global _ranking_cache_service + if _ranking_cache_service is None: + _ranking_cache_service = RankingCacheService(redis) + return _ranking_cache_service + + +async def schedule_ranking_refresh_task(session: AsyncSession, redis: Redis): + """定时排行榜刷新任务""" + # 默认启用排行榜缓存,除非明确禁用 + enable_ranking_cache = getattr(settings, "enable_ranking_cache", True) + if not enable_ranking_cache: + return + + cache_service = get_ranking_cache_service(redis) + try: + await cache_service.refresh_all_rankings(session) + except Exception as e: + logger.error(f"Scheduled ranking refresh task failed: {e}")