This commit is contained in:
咕谷酱
2025-08-22 00:07:19 +08:00
parent bade8658ed
commit 80d4237c5d
22 changed files with 423 additions and 356 deletions

View File

@@ -2,11 +2,12 @@
用户排行榜缓存服务
用于缓存用户排行榜数据,减轻数据库压力
"""
from __future__ import annotations
import asyncio
from datetime import UTC, datetime
import json
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Literal
from app.config import settings
@@ -24,7 +25,7 @@ if TYPE_CHECKING:
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
return json.dumps(
data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
)
class RankingCacheService:
@@ -84,7 +87,7 @@ class RankingCacheService:
try:
cache_key = self._get_cache_key(ruleset, type, country, page)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
@@ -107,11 +110,7 @@ class RankingCacheService:
# 使用配置文件的TTL设置
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set(
cache_key,
safe_json_dumps(ranking_data),
ex=ttl
)
await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
logger.debug(f"Cached ranking data for {cache_key}")
except Exception as e:
logger.error(f"Error caching ranking: {e}")
@@ -126,7 +125,7 @@ class RankingCacheService:
try:
cache_key = self._get_stats_cache_key(ruleset, type, country)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
@@ -148,11 +147,7 @@ class RankingCacheService:
# 使用配置文件的TTL设置统计信息缓存时间更长
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set(
cache_key,
safe_json_dumps(stats),
ex=ttl
)
await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
logger.debug(f"Cached stats for {cache_key}")
except Exception as e:
logger.error(f"Error caching stats: {e}")
@@ -166,7 +161,7 @@ class RankingCacheService:
try:
cache_key = self._get_country_cache_key(ruleset, page)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
@@ -186,11 +181,7 @@ class RankingCacheService:
cache_key = self._get_country_cache_key(ruleset, page)
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60
await self.redis.set(
cache_key,
safe_json_dumps(ranking_data),
ex=ttl
)
await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
logger.debug(f"Cached country ranking data for {cache_key}")
except Exception as e:
logger.error(f"Error caching country ranking: {e}")
@@ -200,7 +191,7 @@ class RankingCacheService:
try:
cache_key = self._get_country_stats_cache_key(ruleset)
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
@@ -219,11 +210,7 @@ class RankingCacheService:
cache_key = self._get_country_stats_cache_key(ruleset)
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
await self.redis.set(
cache_key,
safe_json_dumps(stats),
ex=ttl
)
await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
logger.debug(f"Cached country stats for {cache_key}")
except Exception as e:
logger.error(f"Error caching country stats: {e}")
@@ -238,7 +225,9 @@ class RankingCacheService:
) -> None:
"""刷新排行榜缓存"""
if self._refreshing:
logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
logger.info(
f"Ranking cache refresh already in progress for {ruleset}:{type}"
)
return
# 使用配置文件的设置
@@ -248,7 +237,7 @@ class RankingCacheService:
self._refreshing = True
try:
logger.info(f"Starting ranking cache refresh for {ruleset}:{type}")
# 构建查询条件
wheres = [
col(UserStatistics.mode) == ruleset,
@@ -256,20 +245,22 @@ class RankingCacheService:
col(UserStatistics.is_ranked).is_(True),
]
include = ["user"]
if type == "performance":
order_by = col(UserStatistics.pp).desc()
include.append("rank_change_since_30_days")
else:
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
wheres.append(
col(UserStatistics.user).has(country_code=country.upper())
)
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
total_users = len((await session.exec(total_users_query)).all())
# 计算统计信息
stats = {
"total_users": total_users,
@@ -278,7 +269,7 @@ class RankingCacheService:
"ruleset": ruleset,
"country": country,
}
# 缓存统计信息
await self.cache_stats(ruleset, type, stats, country)
@@ -292,11 +283,11 @@ class RankingCacheService:
.limit(50)
.offset(50 * (page - 1))
)
statistics_data = statistics_list.all()
if not statistics_data:
break # 没有更多数据
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
@@ -306,21 +297,19 @@ class RankingCacheService:
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
user_dict = json.loads(user_stats_resp.model_dump_json())
ranking_data.append(user_dict)
# 缓存这一页的数据
await self.cache_ranking(
ruleset, type, ranking_data, country, page
)
await self.cache_ranking(ruleset, type, ranking_data, country, page)
# 添加延迟避免数据库过载
if page < max_pages:
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}")
logger.info(f"Completed ranking cache refresh for {ruleset}:{type}")
except Exception as e:
logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}")
finally:
@@ -334,7 +323,9 @@ class RankingCacheService:
) -> None:
"""刷新地区排行榜缓存"""
if self._refreshing:
logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
logger.info(
f"Country ranking cache refresh already in progress for {ruleset}"
)
return
if max_pages is None:
@@ -343,17 +334,18 @@ class RankingCacheService:
self._refreshing = True
try:
logger.info(f"Starting country ranking cache refresh for {ruleset}")
# 获取所有国家
from app.database import User
countries = (await session.exec(select(User.country_code).distinct())).all()
# 计算每个国家的统计数据
country_stats_list = []
for country in countries:
if not country: # 跳过空的国家代码
continue
statistics = (
await session.exec(
select(UserStatistics).where(
@@ -364,10 +356,10 @@ class RankingCacheService:
)
)
).all()
if not statistics: # 跳过没有数据的国家
continue
pp = 0
country_stats = {
"code": country,
@@ -376,48 +368,48 @@ class RankingCacheService:
"ranked_score": 0,
"performance": 0,
}
for stat in statistics:
country_stats["active_users"] += 1
country_stats["play_count"] += stat.play_count
country_stats["ranked_score"] += stat.ranked_score
pp += stat.pp
country_stats["performance"] = round(pp)
country_stats_list.append(country_stats)
# 按表现分排序
country_stats_list.sort(key=lambda x: x["performance"], reverse=True)
# 计算统计信息
stats = {
"total_countries": len(country_stats_list),
"last_updated": datetime.now(UTC).isoformat(),
"ruleset": ruleset,
}
# 缓存统计信息
await self.cache_country_stats(ruleset, stats)
# 分页缓存数据每页50个国家
page_size = 50
for page in range(1, max_pages + 1):
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
page_data = country_stats_list[start_idx:end_idx]
if not page_data:
break # 没有更多数据
# 缓存这一页的数据
await self.cache_country_ranking(ruleset, page_data, page)
# 添加延迟避免Redis过载
if page < max_pages and page_data:
await asyncio.sleep(0.1)
logger.info(f"Completed country ranking cache refresh for {ruleset}")
except Exception as e:
logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}")
finally:
@@ -427,11 +419,12 @@ class RankingCacheService:
"""刷新所有排行榜缓存"""
game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
# 获取需要缓存的国家列表活跃用户数量前20的国家
from app.database import User
from sqlmodel import func
countries_query = (
await session.exec(
select(User.country_code, func.count().label("user_count"))
@@ -441,38 +434,40 @@ class RankingCacheService:
.limit(settings.ranking_cache_top_countries)
)
).all()
top_countries = [country for country, _ in countries_query]
refresh_tasks = []
# 全球排行榜
for mode in game_modes:
for ranking_type in ranking_types:
task = self.refresh_ranking_cache(session, mode, ranking_type)
refresh_tasks.append(task)
# 国家排行榜仅前20个国家
for country in top_countries:
for mode in game_modes:
for ranking_type in ranking_types:
task = self.refresh_ranking_cache(session, mode, ranking_type, country)
task = self.refresh_ranking_cache(
session, mode, ranking_type, country
)
refresh_tasks.append(task)
# 地区排行榜
for mode in game_modes:
task = self.refresh_country_ranking_cache(session, mode)
refresh_tasks.append(task)
# 并发执行刷新任务,但限制并发数
semaphore = asyncio.Semaphore(5) # 最多同时5个任务
async def bounded_refresh(task):
async with semaphore:
await task
bounded_tasks = [bounded_refresh(task) for task in refresh_tasks]
try:
await asyncio.gather(*bounded_tasks, return_exceptions=True)
logger.info("All ranking cache refresh completed")
@@ -489,7 +484,7 @@ class RankingCacheService:
"""使缓存失效"""
try:
deleted_keys = 0
if ruleset and type:
# 删除特定的用户排行榜缓存
country_part = f":{country.upper()}" if country else ""
@@ -498,12 +493,14 @@ class RankingCacheService:
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
logger.info(
f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
)
elif ruleset:
# 删除特定游戏模式的所有缓存
patterns = [
f"ranking:{ruleset}:*",
f"country_ranking:{ruleset}:*" if include_country_ranking else None
f"country_ranking:{ruleset}:*" if include_country_ranking else None,
]
for pattern in patterns:
if pattern:
@@ -516,15 +513,15 @@ class RankingCacheService:
patterns = ["ranking:*"]
if include_country_ranking:
patterns.append("country_ranking:*")
for pattern in patterns:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
logger.info(f"Invalidated all {deleted_keys} ranking cache keys")
except Exception as e:
logger.error(f"Error invalidating cache: {e}")
@@ -535,7 +532,7 @@ class RankingCacheService:
pattern = f"country_ranking:{ruleset}:*"
else:
pattern = "country_ranking:*"
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
@@ -550,10 +547,10 @@ class RankingCacheService:
ranking_keys = await self.redis.keys("ranking:*")
# 获取地区排行榜缓存
country_keys = await self.redis.keys("country_ranking:*")
total_keys = ranking_keys + country_keys
total_size = 0
for key in total_keys[:100]: # 限制检查数量以避免性能问题
try:
size = await self.redis.memory_usage(key)
@@ -561,7 +558,7 @@ class RankingCacheService:
total_size += size
except Exception:
continue
return {
"cached_user_rankings": len(ranking_keys),
"cached_country_rankings": len(country_keys),