add ranking cache

This commit is contained in:
咕谷酱
2025-08-21 21:35:08 +08:00
parent 56e83fa098
commit b316511cf5
4 changed files with 879 additions and 18 deletions

View File

@@ -142,6 +142,13 @@ class Settings(BaseSettings):
beatmap_cache_expire_hours: int = 24
max_concurrent_pp_calculations: int = 10
enable_pp_calculation_threading: bool = True
# 排行榜缓存设置
enable_ranking_cache: bool = True
ranking_cache_expire_minutes: int = 10 # 排行榜缓存过期时间(分钟)
ranking_cache_refresh_interval_minutes: int = 10 # 排行榜缓存刷新间隔(分钟)
ranking_cache_max_pages: int = 20 # 最多缓存的页数
ranking_cache_top_countries: int = 20 # 缓存前N个国家的排行榜
# 反作弊设置
suspicious_score_check: bool = True

View File

@@ -2,11 +2,13 @@ from __future__ import annotations
from typing import Literal
from app.config import settings
from app.database import User
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.dependencies.database import Database, get_redis
from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service
from .router import router
@@ -40,9 +42,27 @@ async def get_country_ranking(
page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]),
):
# 获取 Redis 连接和缓存服务
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
# 尝试从缓存获取数据
cached_data = await cache_service.get_cached_country_ranking(ruleset, page)
if cached_data:
# 从缓存返回数据
return CountryResponse(
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
# 缓存未命中,从数据库查询
response = CountryResponse(ranking=[])
countries = (await session.exec(select(User.country_code).distinct())).all()
for country in countries:
if not country: # 跳过空的国家代码
continue
statistics = (
await session.exec(
select(UserStatistics).where(
@@ -53,6 +73,10 @@ async def get_country_ranking(
)
)
).all()
if not statistics: # 跳过没有数据的国家
continue
pp = 0
country_stats = CountryStatistics(
code=country,
@@ -68,7 +92,29 @@ async def get_country_ranking(
pp += stat.pp
country_stats.performance = round(pp)
response.ranking.append(country_stats)
response.ranking.sort(key=lambda x: x.performance, reverse=True)
# 分页处理
page_size = 50
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
# 获取当前页的数据
current_page_data = response.ranking[start_idx:end_idx]
# 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
# 返回当前页的结果
response.ranking = current_page_data
return response
@@ -93,6 +139,24 @@ async def get_user_ranking(
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
):
# 获取 Redis 连接和缓存服务
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
# 尝试从缓存获取数据
cached_data = await cache_service.get_cached_ranking(
ruleset, type, country, page
)
if cached_data:
# 从缓存返回数据
return TopUsersResponse(
ranking=[
UserStatisticsResp.model_validate(item) for item in cached_data
]
)
# 缓存未命中,从数据库查询
wheres = [
col(UserStatistics.mode) == ruleset,
col(UserStatistics.pp) > 0,
@@ -106,6 +170,7 @@ async def get_user_ranking(
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
@@ -113,10 +178,154 @@ async def get_user_ranking(
.limit(50)
.offset(50 * (page - 1))
)
resp = TopUsersResponse(
ranking=[
await UserStatisticsResp.from_db(statistics, session, None, include)
for statistics in statistics_list
]
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking(
ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data)
return resp
""" @router.post(
"/rankings/cache/refresh",
name="刷新排行榜缓存",
description="手动刷新排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_ranking_cache(
session: Database,
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
if ruleset and type:
# 刷新特定的用户排行榜
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# 如果请求刷新地区排行榜
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
await cache_service.refresh_country_ranking_cache(session, ruleset)
message += f" and country ranking for {ruleset}"
return {"message": message}
elif ruleset:
# 刷新特定游戏模式的所有排行榜
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
for ranking_type in ranking_types:
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
if include_country_ranking:
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed all ranking caches for {ruleset}"}
else:
# 刷新所有排行榜
await cache_service.refresh_all_rankings(session)
return {"message": "Refreshed all ranking caches"}
@router.post(
"/rankings/{ruleset}/country/cache/refresh",
name="刷新地区排行榜缓存",
description="手动刷新地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_country_ranking_cache(
session: Database,
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed country ranking cache for {ruleset}"}
@router.delete(
"/rankings/cache",
name="清除排行榜缓存",
description="清除排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
if ruleset and type:
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
if include_country_ranking:
message += " and country ranking"
return {"message": message}
else:
message = "Cleared all ranking caches"
if include_country_ranking:
message += " including country rankings"
return {"message": message}
@router.delete(
"/rankings/{ruleset}/country/cache",
name="清除地区排行榜缓存",
description="清除地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_country_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_country_cache(ruleset)
if ruleset:
return {"message": f"Cleared country ranking cache for {ruleset}"}
else:
return {"message": "Cleared all country ranking caches"}
@router.get(
"/rankings/cache/stats",
name="获取排行榜缓存统计",
description="获取排行榜缓存统计信息(管理员功能)",
tags=["排行榜", "管理"],
)
async def get_ranking_cache_stats(
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats()
return stats """

View File

@@ -2,13 +2,14 @@ from __future__ import annotations
import asyncio
from app.dependencies.database import get_redis
from app.config import settings
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.log import logger
class BeatmapsetCacheScheduler:
"""谱面集缓存调度器"""
class CacheScheduler:
"""缓存调度器 - 统一管理各种缓存任务"""
def __init__(self):
self.running = False
@@ -21,7 +22,7 @@ class BeatmapsetCacheScheduler:
self.running = True
self.task = asyncio.create_task(self._run_scheduler())
logger.info("BeatmapsetCacheScheduler started")
logger.info("CacheScheduler started")
async def stop(self):
"""停止调度器"""
@@ -32,20 +33,47 @@ class BeatmapsetCacheScheduler:
await self.task
except asyncio.CancelledError:
pass
logger.info("BeatmapsetCacheScheduler stopped")
logger.info("CacheScheduler stopped")
async def _run_scheduler(self):
"""运行调度器主循环"""
# 启动时立即执行一次预热
await self._warmup_cache()
# 启动时执行一次排行榜缓存刷新
await self._refresh_ranking_cache()
beatmap_cache_counter = 0
ranking_cache_counter = 0
# 从配置文件获取间隔设置
check_interval = 5 * 60 # 5分钟检查间隔
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
beatmap_cache_cycles = beatmap_cache_interval // check_interval
ranking_cache_cycles = ranking_cache_interval // check_interval
while self.running:
try:
# 每30分钟执行一次缓存预热
await asyncio.sleep(30 * 60) # 30分钟
# 每5分钟检查一次
await asyncio.sleep(check_interval)
if self.running:
if not self.running:
break
beatmap_cache_counter += 1
ranking_cache_counter += 1
# beatmap缓存预热
if beatmap_cache_counter >= beatmap_cache_cycles:
await self._warmup_cache()
beatmap_cache_counter = 0
# 排行榜缓存刷新
if ranking_cache_counter >= ranking_cache_cycles:
await self._refresh_ranking_cache()
ranking_cache_counter = 0
except asyncio.CancelledError:
break
@@ -56,7 +84,7 @@ class BeatmapsetCacheScheduler:
async def _warmup_cache(self):
"""执行缓存预热"""
try:
logger.info("Starting cache warmup...")
logger.info("Starting beatmap cache warmup...")
fetcher = await get_fetcher()
redis = get_redis()
@@ -64,14 +92,45 @@ class BeatmapsetCacheScheduler:
# 预热主页缓存
await fetcher.warmup_homepage_cache(redis)
logger.info("Cache warmup completed successfully")
logger.info("Beatmap cache warmup completed successfully")
except Exception as e:
logger.error(f"Cache warmup failed: {e}")
logger.error(f"Beatmap cache warmup failed: {e}")
async def _refresh_ranking_cache(self):
"""刷新排行榜缓存"""
try:
logger.info("Starting ranking cache refresh...")
redis = get_redis()
# 导入排行榜缓存服务
from app.service.ranking_cache_service import (
get_ranking_cache_service,
schedule_ranking_refresh_task,
)
# 获取数据库会话
async for session in get_db():
await schedule_ranking_refresh_task(session, redis)
break # 只需要一次会话
logger.info("Ranking cache refresh completed successfully")
except Exception as e:
logger.error(f"Ranking cache refresh failed: {e}")
# Beatmap缓存调度器保持向后兼容
class BeatmapsetCacheScheduler(CacheScheduler):
"""谱面集缓存调度器 - 为了向后兼容"""
pass
# 全局调度器实例
cache_scheduler = BeatmapsetCacheScheduler()
cache_scheduler = CacheScheduler()
# 保持向后兼容的别名
beatmapset_cache_scheduler = BeatmapsetCacheScheduler()
async def start_cache_scheduler():

View File

@@ -0,0 +1,586 @@
"""
用户排行榜缓存服务
用于缓存用户排行榜数据,减轻数据库压力
"""
from __future__ import annotations
import asyncio
import json
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Literal
from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.log import logger
from app.models.score import GameMode
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
pass
class RankingCacheService:
"""用户排行榜缓存服务"""
def __init__(self, redis: Redis):
self.redis = redis
self._refreshing = False
self._background_tasks: set = set()
def _get_cache_key(
self,
ruleset: GameMode,
type: Literal["performance", "score"],
country: str | None = None,
page: int = 1,
) -> str:
"""生成缓存键"""
country_part = f":{country.upper()}" if country else ""
return f"ranking:{ruleset}:{type}{country_part}:page:{page}"
def _get_stats_cache_key(
self,
ruleset: GameMode,
type: Literal["performance", "score"],
country: str | None = None,
) -> str:
"""生成统计信息缓存键"""
country_part = f":{country.upper()}" if country else ""
return f"ranking:stats:{ruleset}:{type}{country_part}"
def _get_country_cache_key(self, ruleset: GameMode, page: int = 1) -> str:
"""生成地区排行榜缓存键"""
return f"country_ranking:{ruleset}:page:{page}"
def _get_country_stats_cache_key(self, ruleset: GameMode) -> str:
"""生成地区排行榜统计信息缓存键"""
return f"country_ranking:stats:{ruleset}"
async def get_cached_ranking(
self,
ruleset: GameMode,
type: Literal["performance", "score"],
country: str | None = None,
page: int = 1,
) -> list[dict] | None:
"""获取缓存的排行榜数据"""
try:
cache_key = self._get_cache_key(ruleset, type, country, page)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error getting cached ranking: {e}")
return None
async def cache_ranking(
self,
ruleset: GameMode,
type: Literal["performance", "score"],
ranking_data: list[dict],
country: str | None = None,
page: int = 1,
ttl: int | None = None, # 允许为None以使用配置文件的默认值
) -> None:
"""缓存排行榜数据"""
try:
cache_key = self._get_cache_key(ruleset, type, country, page)
# 使用配置文件的TTL设置
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set(
cache_key,
json.dumps(ranking_data, separators=(",", ":")),
ex=ttl
)
logger.debug(f"Cached ranking data for {cache_key}")
except Exception as e:
logger.error(f"Error caching ranking: {e}")
async def get_cached_stats(
self,
ruleset: GameMode,
type: Literal["performance", "score"],
country: str | None = None,
) -> dict | None:
"""获取缓存的统计信息"""
try:
cache_key = self._get_stats_cache_key(ruleset, type, country)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error getting cached stats: {e}")
return None
async def cache_stats(
self,
ruleset: GameMode,
type: Literal["performance", "score"],
stats: dict,
country: str | None = None,
ttl: int | None = None, # 允许为None以使用配置文件的默认值
) -> None:
"""缓存统计信息"""
try:
cache_key = self._get_stats_cache_key(ruleset, type, country)
# 使用配置文件的TTL设置统计信息缓存时间更长
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set(
cache_key,
json.dumps(stats, separators=(",", ":")),
ex=ttl
)
logger.debug(f"Cached stats for {cache_key}")
except Exception as e:
logger.error(f"Error caching stats: {e}")
async def get_cached_country_ranking(
self,
ruleset: GameMode,
page: int = 1,
) -> list[dict] | None:
"""获取缓存的地区排行榜数据"""
try:
cache_key = self._get_country_cache_key(ruleset, page)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error getting cached country ranking: {e}")
return None
async def cache_country_ranking(
self,
ruleset: GameMode,
ranking_data: list[dict],
page: int = 1,
ttl: int | None = None,
) -> None:
"""缓存地区排行榜数据"""
try:
cache_key = self._get_country_cache_key(ruleset, page)
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set(
cache_key,
json.dumps(ranking_data, separators=(",", ":")),
ex=ttl
)
logger.debug(f"Cached country ranking data for {cache_key}")
except Exception as e:
logger.error(f"Error caching country ranking: {e}")
async def get_cached_country_stats(self, ruleset: GameMode) -> dict | None:
"""获取缓存的地区排行榜统计信息"""
try:
cache_key = self._get_country_stats_cache_key(ruleset)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error getting cached country stats: {e}")
return None
async def cache_country_stats(
self,
ruleset: GameMode,
stats: dict,
ttl: int | None = None,
) -> None:
"""缓存地区排行榜统计信息"""
try:
cache_key = self._get_country_stats_cache_key(ruleset)
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set(
cache_key,
json.dumps(stats, separators=(",", ":")),
ex=ttl
)
logger.debug(f"Cached country stats for {cache_key}")
except Exception as e:
logger.error(f"Error caching country stats: {e}")
async def refresh_ranking_cache(
self,
session: AsyncSession,
ruleset: GameMode,
type: Literal["performance", "score"],
country: str | None = None,
max_pages: int | None = None, # 允许为None以使用配置文件的默认值
) -> None:
"""刷新排行榜缓存"""
if self._refreshing:
logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
return
# 使用配置文件的设置
if max_pages is None:
max_pages = settings.ranking_cache_max_pages
self._refreshing = True
try:
logger.info(f"Starting ranking cache refresh for {ruleset}:{type}")
# 构建查询条件
wheres = [
col(UserStatistics.mode) == ruleset,
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked).is_(True),
]
include = ["user"]
if type == "performance":
order_by = col(UserStatistics.pp).desc()
include.append("rank_change_since_30_days")
else:
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
total_users = len((await session.exec(total_users_query)).all())
# 计算统计信息
stats = {
"total_users": total_users,
"last_updated": datetime.now(UTC).isoformat(),
"type": type,
"ruleset": ruleset,
"country": country,
}
# 缓存统计信息
await self.cache_stats(ruleset, type, stats, country)
# 分页缓存数据
for page in range(1, max_pages + 1):
try:
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
)
statistics_data = statistics_list.all()
if not statistics_data:
break # 没有更多数据
# 转换为响应格式
ranking_data = []
for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
ranking_data.append(user_stats_resp.model_dump())
# 缓存这一页的数据
await self.cache_ranking(
ruleset, type, ranking_data, country, page
)
# 添加延迟避免数据库过载
if page < max_pages:
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}")
logger.info(f"Completed ranking cache refresh for {ruleset}:{type}")
except Exception as e:
logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}")
finally:
self._refreshing = False
async def refresh_country_ranking_cache(
self,
session: AsyncSession,
ruleset: GameMode,
max_pages: int | None = None,
) -> None:
"""刷新地区排行榜缓存"""
if self._refreshing:
logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
return
if max_pages is None:
max_pages = settings.ranking_cache_max_pages
self._refreshing = True
try:
logger.info(f"Starting country ranking cache refresh for {ruleset}")
# 获取所有国家
from app.database import User
countries = (await session.exec(select(User.country_code).distinct())).all()
# 计算每个国家的统计数据
country_stats_list = []
for country in countries:
if not country: # 跳过空的国家代码
continue
statistics = (
await session.exec(
select(UserStatistics).where(
UserStatistics.mode == ruleset,
UserStatistics.pp > 0,
col(UserStatistics.user).has(country_code=country),
col(UserStatistics.user).has(is_active=True),
)
)
).all()
if not statistics: # 跳过没有数据的国家
continue
pp = 0
country_stats = {
"code": country,
"active_users": 0,
"play_count": 0,
"ranked_score": 0,
"performance": 0,
}
for stat in statistics:
country_stats["active_users"] += 1
country_stats["play_count"] += stat.play_count
country_stats["ranked_score"] += stat.ranked_score
pp += stat.pp
country_stats["performance"] = round(pp)
country_stats_list.append(country_stats)
# 按表现分排序
country_stats_list.sort(key=lambda x: x["performance"], reverse=True)
# 计算统计信息
stats = {
"total_countries": len(country_stats_list),
"last_updated": datetime.now(UTC).isoformat(),
"ruleset": ruleset,
}
# 缓存统计信息
await self.cache_country_stats(ruleset, stats)
# 分页缓存数据每页50个国家
page_size = 50
for page in range(1, max_pages + 1):
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
page_data = country_stats_list[start_idx:end_idx]
if not page_data:
break # 没有更多数据
# 缓存这一页的数据
await self.cache_country_ranking(ruleset, page_data, page)
# 添加延迟避免Redis过载
if page < max_pages and page_data:
await asyncio.sleep(0.1)
logger.info(f"Completed country ranking cache refresh for {ruleset}")
except Exception as e:
logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}")
finally:
self._refreshing = False
async def refresh_all_rankings(self, session: AsyncSession) -> None:
"""刷新所有排行榜缓存"""
game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
# 获取需要缓存的国家列表活跃用户数量前20的国家
from app.database import User
from sqlmodel import func
countries_query = (
await session.exec(
select(User.country_code, func.count().label("user_count"))
.where(col(User.is_active).is_(True))
.group_by(User.country_code)
.order_by(col("user_count").desc())
.limit(settings.ranking_cache_top_countries)
)
).all()
top_countries = [country for country, _ in countries_query]
refresh_tasks = []
# 全球排行榜
for mode in game_modes:
for ranking_type in ranking_types:
task = self.refresh_ranking_cache(session, mode, ranking_type)
refresh_tasks.append(task)
# 国家排行榜仅前20个国家
for country in top_countries:
for mode in game_modes:
for ranking_type in ranking_types:
task = self.refresh_ranking_cache(session, mode, ranking_type, country)
refresh_tasks.append(task)
# 地区排行榜
for mode in game_modes:
task = self.refresh_country_ranking_cache(session, mode)
refresh_tasks.append(task)
# 并发执行刷新任务,但限制并发数
semaphore = asyncio.Semaphore(5) # 最多同时5个任务
async def bounded_refresh(task):
async with semaphore:
await task
bounded_tasks = [bounded_refresh(task) for task in refresh_tasks]
try:
await asyncio.gather(*bounded_tasks, return_exceptions=True)
logger.info("All ranking cache refresh completed")
except Exception as e:
logger.error(f"Error in batch ranking cache refresh: {e}")
async def invalidate_cache(
self,
ruleset: GameMode | None = None,
type: Literal["performance", "score"] | None = None,
country: str | None = None,
include_country_ranking: bool = True,
) -> None:
"""使缓存失效"""
try:
deleted_keys = 0
if ruleset and type:
# 删除特定的用户排行榜缓存
country_part = f":{country.upper()}" if country else ""
pattern = f"ranking:{ruleset}:{type}{country_part}:page:*"
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
elif ruleset:
# 删除特定游戏模式的所有缓存
patterns = [
f"ranking:{ruleset}:*",
f"country_ranking:{ruleset}:*" if include_country_ranking else None
]
for pattern in patterns:
if pattern:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
else:
# 删除所有排行榜缓存
patterns = ["ranking:*"]
if include_country_ranking:
patterns.append("country_ranking:*")
for pattern in patterns:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
logger.info(f"Invalidated all {deleted_keys} ranking cache keys")
except Exception as e:
logger.error(f"Error invalidating cache: {e}")
async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None:
"""使地区排行榜缓存失效"""
try:
if ruleset:
pattern = f"country_ranking:{ruleset}:*"
else:
pattern = "country_ranking:*"
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
logger.info(f"Invalidated {len(keys)} country ranking cache keys")
except Exception as e:
logger.error(f"Error invalidating country cache: {e}")
async def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
try:
# 获取用户排行榜缓存
ranking_keys = await self.redis.keys("ranking:*")
# 获取地区排行榜缓存
country_keys = await self.redis.keys("country_ranking:*")
total_keys = ranking_keys + country_keys
total_size = 0
for key in total_keys[:100]: # 限制检查数量以避免性能问题
try:
size = await self.redis.memory_usage(key)
if size:
total_size += size
except Exception:
continue
return {
"cached_user_rankings": len(ranking_keys),
"cached_country_rankings": len(country_keys),
"total_cached_rankings": len(total_keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"refreshing": self._refreshing,
}
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return {"error": str(e)}
# 全局缓存服务实例
_ranking_cache_service: RankingCacheService | None = None
def get_ranking_cache_service(redis: Redis) -> RankingCacheService:
"""获取排行榜缓存服务实例"""
global _ranking_cache_service
if _ranking_cache_service is None:
_ranking_cache_service = RankingCacheService(redis)
return _ranking_cache_service
async def schedule_ranking_refresh_task(session: AsyncSession, redis: Redis):
"""定时排行榜刷新任务"""
# 默认启用排行榜缓存,除非明确禁用
enable_ranking_cache = getattr(settings, "enable_ranking_cache", True)
if not enable_ranking_cache:
return
cache_service = get_ranking_cache_service(redis)
try:
await cache_service.refresh_all_rankings(session)
except Exception as e:
logger.error(f"Scheduled ranking refresh task failed: {e}")