咕谷酱
2025-08-22 00:07:19 +08:00
parent bade8658ed
commit 80d4237c5d
22 changed files with 423 additions and 356 deletions

View File

@@ -2,6 +2,7 @@
 Beatmap cache prefetch service
 Pre-caches popular beatmaps to reduce fetch latency during score calculation
 """
+from __future__ import annotations
 import asyncio
@@ -155,9 +156,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
 async def schedule_preload_task(
-    session: AsyncSession,
-    redis: Redis,
-    fetcher: "Fetcher"
+    session: AsyncSession, redis: Redis, fetcher: "Fetcher"
 ):
     """
     Scheduled preload task
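A scheduled preload task of this shape is typically a sleep-driven loop; as a rough standalone sketch (the interval and the preload callback are assumptions for illustration, not taken from this commit):

    import asyncio

    async def run_periodically(preload, interval_seconds: float = 300.0):
        # Invoke the preload coroutine on a fixed interval; report failures
        # instead of letting one bad run stop the schedule.
        while True:
            try:
                await preload()
            except Exception as e:
                print(f"preload failed: {e}")
            await asyncio.sleep(interval_seconds)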

View File

@@ -2,11 +2,12 @@
 User ranking cache service
 Caches user ranking data to reduce database load
 """
+from __future__ import annotations
 import asyncio
-from datetime import UTC, datetime
 import json
+from datetime import UTC, datetime, timedelta
 from typing import TYPE_CHECKING, Literal
 from app.config import settings
@@ -24,7 +25,7 @@ if TYPE_CHECKING:
 class DateTimeEncoder(json.JSONEncoder):
     """Custom JSON encoder with datetime serialization support"""
     def default(self, obj):
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
 def safe_json_dumps(data) -> str:
     """Safe JSON serialization with support for datetime objects"""
-    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
+    return json.dumps(
+        data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
+    )
 class RankingCacheService:
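For reference, what safe_json_dumps changes versus plain json.dumps: datetime values are not JSON-serializable by default, and the cls= hook routes them through default(). A minimal standalone check:

    import json
    from datetime import UTC, datetime

    class DateTimeEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, datetime):
                return obj.isoformat()
            return super().default(obj)

    data = {"last_updated": datetime.now(UTC)}
    # json.dumps(data) alone raises TypeError: Object of type datetime is not JSON serializable
    print(json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")))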
@@ -84,7 +87,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_cache_key(ruleset, type, country, page)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -107,11 +110,7 @@ class RankingCacheService:
             # Use the TTL from the config file
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(ranking_data),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
             logger.debug(f"Cached ranking data for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching ranking: {e}")
@@ -126,7 +125,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_stats_cache_key(ruleset, type, country)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -148,11 +147,7 @@ class RankingCacheService:
             # Use the TTL from the config file; statistics are cached longer
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60 * 6  # 6x the base TTL
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(stats),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
             logger.debug(f"Cached stats for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching stats: {e}")
@@ -166,7 +161,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_country_cache_key(ruleset, page)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -186,11 +181,7 @@ class RankingCacheService:
             cache_key = self._get_country_cache_key(ruleset, page)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(ranking_data),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
             logger.debug(f"Cached country ranking data for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching country ranking: {e}")
@@ -200,7 +191,7 @@ class RankingCacheService:
         try:
             cache_key = self._get_country_stats_cache_key(ruleset)
             cached_data = await self.redis.get(cache_key)
             if cached_data:
                 return json.loads(cached_data)
             return None
@@ -219,11 +210,7 @@ class RankingCacheService:
             cache_key = self._get_country_stats_cache_key(ruleset)
             if ttl is None:
                 ttl = settings.ranking_cache_expire_minutes * 60 * 6  # 6x the base TTL
-            await self.redis.set(
-                cache_key,
-                safe_json_dumps(stats),
-                ex=ttl
-            )
+            await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
             logger.debug(f"Cached country stats for {cache_key}")
         except Exception as e:
             logger.error(f"Error caching country stats: {e}")
@@ -238,7 +225,9 @@ class RankingCacheService:
     ) -> None:
         """Refresh the ranking cache"""
         if self._refreshing:
-            logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
+            logger.info(
+                f"Ranking cache refresh already in progress for {ruleset}:{type}"
+            )
             return
         # Use settings from the config file
@@ -248,7 +237,7 @@ class RankingCacheService:
         self._refreshing = True
         try:
             logger.info(f"Starting ranking cache refresh for {ruleset}:{type}")
             # Build the query conditions
             wheres = [
                 col(UserStatistics.mode) == ruleset,
@@ -256,20 +245,22 @@ class RankingCacheService:
                 col(UserStatistics.is_ranked).is_(True),
             ]
             include = ["user"]
             if type == "performance":
                 order_by = col(UserStatistics.pp).desc()
                 include.append("rank_change_since_30_days")
             else:
                 order_by = col(UserStatistics.ranked_score).desc()
             if country:
-                wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
+                wheres.append(
+                    col(UserStatistics.user).has(country_code=country.upper())
+                )
             # Get the total user count for statistics
             total_users_query = select(UserStatistics).where(*wheres)
             total_users = len((await session.exec(total_users_query)).all())
             # Compute the statistics
             stats = {
                 "total_users": total_users,
@@ -278,7 +269,7 @@ class RankingCacheService:
                 "ruleset": ruleset,
                 "country": country,
             }
             # Cache the statistics
             await self.cache_stats(ruleset, type, stats, country)
@@ -292,11 +283,11 @@ class RankingCacheService:
                         .limit(50)
                         .offset(50 * (page - 1))
                     )
                     statistics_data = statistics_list.all()
                     if not statistics_data:
                         break  # no more data
                     # Convert to the response format and ensure correct serialization
                     ranking_data = []
                     for statistics in statistics_data:
@@ -306,21 +297,19 @@ class RankingCacheService:
                         # Convert UserStatisticsResp to a dict, handling all serialization issues
                         user_dict = json.loads(user_stats_resp.model_dump_json())
                         ranking_data.append(user_dict)
                     # Cache this page of data
-                    await self.cache_ranking(
-                        ruleset, type, ranking_data, country, page
-                    )
+                    await self.cache_ranking(ruleset, type, ranking_data, country, page)
                     # Add a delay to avoid overloading the database
                     if page < max_pages:
                         await asyncio.sleep(0.1)
                 except Exception as e:
                     logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}")
             logger.info(f"Completed ranking cache refresh for {ruleset}:{type}")
         except Exception as e:
             logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}")
         finally:
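Each page above maps to LIMIT 50 / OFFSET 50 * (page - 1); a quick sanity check of the paging arithmetic:

    page_size = 50
    for page in range(1, 4):
        print(f"page {page}: LIMIT {page_size} OFFSET {page_size * (page - 1)}")
    # page 1: LIMIT 50 OFFSET 0
    # page 2: LIMIT 50 OFFSET 50
    # page 3: LIMIT 50 OFFSET 100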
@@ -334,7 +323,9 @@ class RankingCacheService:
     ) -> None:
         """Refresh the country ranking cache"""
         if self._refreshing:
-            logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
+            logger.info(
+                f"Country ranking cache refresh already in progress for {ruleset}"
+            )
             return
         if max_pages is None:
@@ -343,17 +334,18 @@ class RankingCacheService:
         self._refreshing = True
         try:
             logger.info(f"Starting country ranking cache refresh for {ruleset}")
             # Fetch all countries
             from app.database import User
             countries = (await session.exec(select(User.country_code).distinct())).all()
             # Compute statistics for each country
             country_stats_list = []
             for country in countries:
                 if not country:  # skip empty country codes
                     continue
                 statistics = (
                     await session.exec(
                         select(UserStatistics).where(
@@ -364,10 +356,10 @@ class RankingCacheService:
                         )
                     )
                 ).all()
                 if not statistics:  # skip countries with no data
                     continue
                 pp = 0
                 country_stats = {
                     "code": country,
@@ -376,48 +368,48 @@ class RankingCacheService:
                     "ranked_score": 0,
                     "performance": 0,
                 }
                 for stat in statistics:
                     country_stats["active_users"] += 1
                     country_stats["play_count"] += stat.play_count
                     country_stats["ranked_score"] += stat.ranked_score
                     pp += stat.pp
                 country_stats["performance"] = round(pp)
                 country_stats_list.append(country_stats)
             # Sort by performance
             country_stats_list.sort(key=lambda x: x["performance"], reverse=True)
             # Compute the statistics
             stats = {
                 "total_countries": len(country_stats_list),
                 "last_updated": datetime.now(UTC).isoformat(),
                 "ruleset": ruleset,
             }
             # Cache the statistics
             await self.cache_country_stats(ruleset, stats)
             # Cache the data in pages, 50 countries per page
             page_size = 50
             for page in range(1, max_pages + 1):
                 start_idx = (page - 1) * page_size
                 end_idx = start_idx + page_size
                 page_data = country_stats_list[start_idx:end_idx]
                 if not page_data:
                     break  # no more data
                 # Cache this page of data
                 await self.cache_country_ranking(ruleset, page_data, page)
                 # Add a delay to avoid overloading Redis
                 if page < max_pages and page_data:
                     await asyncio.sleep(0.1)
             logger.info(f"Completed country ranking cache refresh for {ruleset}")
         except Exception as e:
             logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}")
         finally:
@@ -427,11 +419,12 @@ class RankingCacheService:
         """Refresh all ranking caches"""
         game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]
         ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
         # Get the list of countries to cache: the top 20 by active user count
         from app.database import User
         from sqlmodel import func
         countries_query = (
             await session.exec(
                 select(User.country_code, func.count().label("user_count"))
@@ -441,38 +434,40 @@ class RankingCacheService:
                 .limit(settings.ranking_cache_top_countries)
             )
         ).all()
         top_countries = [country for country, _ in countries_query]
         refresh_tasks = []
         # Global rankings
         for mode in game_modes:
             for ranking_type in ranking_types:
                 task = self.refresh_ranking_cache(session, mode, ranking_type)
                 refresh_tasks.append(task)
         # Country rankings (top 20 countries only)
         for country in top_countries:
             for mode in game_modes:
                 for ranking_type in ranking_types:
-                    task = self.refresh_ranking_cache(session, mode, ranking_type, country)
+                    task = self.refresh_ranking_cache(
+                        session, mode, ranking_type, country
+                    )
                     refresh_tasks.append(task)
         # Regional rankings
         for mode in game_modes:
             task = self.refresh_country_ranking_cache(session, mode)
             refresh_tasks.append(task)
         # Run the refresh tasks concurrently, but bound the concurrency
         semaphore = asyncio.Semaphore(5)  # at most 5 tasks at a time
         async def bounded_refresh(task):
             async with semaphore:
                 await task
         bounded_tasks = [bounded_refresh(task) for task in refresh_tasks]
         try:
             await asyncio.gather(*bounded_tasks, return_exceptions=True)
             logger.info("All ranking cache refresh completed")
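The bounded_refresh pattern above is worth demonstrating on its own: the coroutines are all created up front, but each is only awaited once it acquires a semaphore slot, so at most five run concurrently. A runnable sketch:

    import asyncio

    async def work(i: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for one cache refresh
        return i

    async def main():
        semaphore = asyncio.Semaphore(5)  # at most 5 tasks at a time

        async def bounded(coro):
            async with semaphore:
                return await coro

        results = await asyncio.gather(
            *(bounded(work(i)) for i in range(20)), return_exceptions=True
        )
        print(results)

    asyncio.run(main())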
@@ -489,7 +484,7 @@ class RankingCacheService:
         """Invalidate caches"""
         try:
             deleted_keys = 0
             if ruleset and type:
                 # Delete the specific user ranking cache
                 country_part = f":{country.upper()}" if country else ""
@@ -498,12 +493,14 @@ class RankingCacheService:
                 if keys:
                     await self.redis.delete(*keys)
                     deleted_keys += len(keys)
-                    logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
+                    logger.info(
+                        f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
+                    )
             elif ruleset:
                 # Delete all caches for the given game mode
                 patterns = [
                     f"ranking:{ruleset}:*",
-                    f"country_ranking:{ruleset}:*" if include_country_ranking else None
+                    f"country_ranking:{ruleset}:*" if include_country_ranking else None,
                 ]
                 for pattern in patterns:
                     if pattern:
@@ -516,15 +513,15 @@ class RankingCacheService:
             patterns = ["ranking:*"]
             if include_country_ranking:
                 patterns.append("country_ranking:*")
             for pattern in patterns:
                 keys = await self.redis.keys(pattern)
                 if keys:
                     await self.redis.delete(*keys)
                     deleted_keys += len(keys)
             logger.info(f"Invalidated all {deleted_keys} ranking cache keys")
         except Exception as e:
             logger.error(f"Error invalidating cache: {e}")
@@ -535,7 +532,7 @@ class RankingCacheService:
             pattern = f"country_ranking:{ruleset}:*"
         else:
             pattern = "country_ranking:*"
         keys = await self.redis.keys(pattern)
         if keys:
             await self.redis.delete(*keys)
@@ -550,10 +547,10 @@ class RankingCacheService:
             ranking_keys = await self.redis.keys("ranking:*")
             # Fetch the regional ranking caches
             country_keys = await self.redis.keys("country_ranking:*")
             total_keys = ranking_keys + country_keys
             total_size = 0
             for key in total_keys[:100]:  # cap the number of keys checked to avoid performance issues
                 try:
                     size = await self.redis.memory_usage(key)
@@ -561,7 +558,7 @@ class RankingCacheService:
                     total_size += size
                 except Exception:
                     continue
             return {
                 "cached_user_rankings": len(ranking_keys),
                 "cached_country_rankings": len(country_keys),

View File

@@ -2,24 +2,23 @@
 User cache service
 Caches user information, providing hot caching and real-time refresh
 """
+from __future__ import annotations
 import asyncio
-from datetime import datetime
 import json
+from datetime import UTC, datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any
 from app.config import settings
 from app.const import BANCHOBOT_ID
 from app.database import User, UserResp
 from app.database.lazer_user import SEARCH_INCLUDED
 from app.database.pp_best_score import PPBestScore
-from app.database.score import Score, ScoreResp
+from app.database.score import ScoreResp
 from app.log import logger
 from app.models.score import GameMode
 from redis.asyncio import Redis
-from sqlmodel import col, exists, select
+from sqlmodel import col, select
 from sqlmodel.ext.asyncio.session import AsyncSession
 if TYPE_CHECKING:
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
 class DateTimeEncoder(json.JSONEncoder):
     """Custom JSON encoder with datetime serialization support"""
     def default(self, obj):
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -48,16 +47,16 @@ class UserCacheService:
         self._refreshing = False
         self._background_tasks: set = set()
-    def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
+    def _get_v1_user_cache_key(
+        self, user_id: int, ruleset: GameMode | None = None
+    ) -> str:
         """Generate a V1 user cache key"""
         if ruleset:
             return f"v1_user:{user_id}:ruleset:{ruleset}"
         return f"v1_user:{user_id}"
     async def get_v1_user_from_cache(
-        self,
-        user_id: int,
-        ruleset: GameMode | None = None
+        self, user_id: int, ruleset: GameMode | None = None
     ) -> dict | None:
         """Get V1 user info from the cache"""
         try:
@@ -72,11 +71,11 @@ class UserCacheService:
             return None
     async def cache_v1_user(
-        self,
-        user_data: dict,
+        self,
+        user_data: dict,
         user_id: int,
         ruleset: GameMode | None = None,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """Cache V1 user info"""
         try:
@@ -97,7 +96,9 @@ class UserCacheService:
             keys = await self.redis.keys(pattern)
             if keys:
                 await self.redis.delete(*keys)
-                logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
+                logger.info(
+                    f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
+                )
         except Exception as e:
             logger.error(f"Error invalidating V1 user cache: {e}")
@@ -108,31 +109,25 @@ class UserCacheService:
         return f"user:{user_id}"
     def _get_user_scores_cache_key(
-        self,
-        user_id: int,
-        score_type: str,
+        self,
+        user_id: int,
+        score_type: str,
         mode: GameMode | None = None,
         limit: int = 100,
-        offset: int = 0
+        offset: int = 0,
     ) -> str:
         """Generate a user scores cache key"""
         mode_part = f":{mode}" if mode else ""
         return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
     def _get_user_beatmapsets_cache_key(
-        self,
-        user_id: int,
-        beatmapset_type: str,
-        limit: int = 100,
-        offset: int = 0
+        self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
     ) -> str:
         """Generate a user beatmapsets cache key"""
         return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
     async def get_user_from_cache(
-        self,
-        user_id: int,
-        ruleset: GameMode | None = None
+        self, user_id: int, ruleset: GameMode | None = None
     ) -> UserResp | None:
         """Get user info from the cache"""
         try:
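Every query parameter is baked into the key, so each (score_type, mode, limit, offset) combination is cached and expires independently. A quick demonstration with assumed values:

    def scores_key(user_id: int, score_type: str, mode=None, limit=100, offset=0) -> str:
        mode_part = f":{mode}" if mode else ""
        return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"

    print(scores_key(123, "best", mode="osu"))
    # user:123:scores:best:osu:limit:100:offset:0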
@@ -148,10 +143,10 @@ class UserCacheService:
             return None
     async def cache_user(
-        self,
-        user_resp: UserResp,
+        self,
+        user_resp: UserResp,
         ruleset: GameMode | None = None,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """Cache user info"""
         try:
@@ -173,14 +168,18 @@ class UserCacheService:
         score_type: str,
         mode: GameMode | None = None,
         limit: int = 100,
-        offset: int = 0
+        offset: int = 0,
     ) -> list[ScoreResp] | None:
         """Get user scores from the cache"""
         try:
-            cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+            cache_key = self._get_user_scores_cache_key(
+                user_id, score_type, mode, limit, offset
+            )
             cached_data = await self.redis.get(cache_key)
             if cached_data:
-                logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
+                logger.debug(
+                    f"User scores cache hit for user {user_id}, type {score_type}"
+                )
                 data = json.loads(cached_data)
                 return [ScoreResp(**score_data) for score_data in data]
             return None
@@ -196,34 +195,38 @@ class UserCacheService:
         mode: GameMode | None = None,
         limit: int = 100,
         offset: int = 0,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """Cache user scores"""
         try:
             if expire_seconds is None:
                 expire_seconds = settings.user_scores_cache_expire_seconds
-            cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+            cache_key = self._get_user_scores_cache_key(
+                user_id, score_type, mode, limit, offset
+            )
             # Use model_dump_json() instead of model_dump() + json.dumps()
             scores_json_list = [score.model_dump_json() for score in scores]
             cached_data = f"[{','.join(scores_json_list)}]"
             await self.redis.setex(cache_key, expire_seconds, cached_data)
-            logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
+            logger.debug(
+                f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
+            )
         except Exception as e:
             logger.error(f"Error caching user scores: {e}")
     async def get_user_beatmapsets_from_cache(
-        self,
-        user_id: int,
-        beatmapset_type: str,
-        limit: int = 100,
-        offset: int = 0
+        self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
     ) -> list[Any] | None:
         """Get user beatmapsets from the cache"""
         try:
-            cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+            cache_key = self._get_user_beatmapsets_cache_key(
+                user_id, beatmapset_type, limit, offset
+            )
             cached_data = await self.redis.get(cache_key)
             if cached_data:
-                logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
+                logger.debug(
+                    f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
+                )
                 return json.loads(cached_data)
             return None
         except Exception as e:
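The f"[{','.join(...)}]" trick above assembles a JSON array directly from each model's model_dump_json() output, skipping a second encode pass; a minimal pydantic v2 sketch (ScoreStub is a hypothetical stand-in for ScoreResp):

    from datetime import UTC, datetime
    from pydantic import BaseModel

    class ScoreStub(BaseModel):  # hypothetical stand-in for ScoreResp
        id: int
        created_at: datetime

    scores = [ScoreStub(id=i, created_at=datetime.now(UTC)) for i in (1, 2)]
    cached_data = f"[{','.join(s.model_dump_json() for s in scores)}]"
    print(cached_data)  # a valid JSON array; datetimes already ISO-encoded by pydantic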
@@ -237,23 +240,27 @@ class UserCacheService:
         beatmapsets: list[Any],
         limit: int = 100,
         offset: int = 0,
-        expire_seconds: int | None = None
+        expire_seconds: int | None = None,
     ):
         """Cache user beatmapsets"""
         try:
             if expire_seconds is None:
                 expire_seconds = settings.user_beatmapsets_cache_expire_seconds
-            cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+            cache_key = self._get_user_beatmapsets_cache_key(
+                user_id, beatmapset_type, limit, offset
+            )
             # Use model_dump_json() for objects that have it; fall back to safe_json_dumps otherwise
             serialized_beatmapsets = []
             for bms in beatmapsets:
-                if hasattr(bms, 'model_dump_json'):
+                if hasattr(bms, "model_dump_json"):
                     serialized_beatmapsets.append(bms.model_dump_json())
                 else:
                     serialized_beatmapsets.append(safe_json_dumps(bms))
             cached_data = f"[{','.join(serialized_beatmapsets)}]"
             await self.redis.setex(cache_key, expire_seconds, cached_data)
-            logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
+            logger.debug(
+                f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
+            )
         except Exception as e:
             logger.error(f"Error caching user beatmapsets: {e}")
@@ -269,7 +276,9 @@ class UserCacheService:
         except Exception as e:
             logger.error(f"Error invalidating user cache: {e}")
-    async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
+    async def invalidate_user_scores_cache(
+        self, user_id: int, mode: GameMode | None = None
+    ):
         """Invalidate the user scores cache"""
         try:
             # Delete caches related to the user's scores
@@ -278,7 +287,9 @@ class UserCacheService:
             keys = await self.redis.keys(pattern)
             if keys:
                 await self.redis.delete(*keys)
-                logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
+                logger.info(
+                    f"Invalidated {len(keys)} score cache entries for user {user_id}"
+                )
         except Exception as e:
             logger.error(f"Error invalidating user scores cache: {e}")
@@ -290,12 +301,10 @@ class UserCacheService:
         self._refreshing = True
         try:
             logger.info(f"Preloading cache for {len(user_ids)} users")
             # Fetch the users in one batch
             users = (
-                await session.exec(
-                    select(User).where(col(User.id).in_(user_ids))
-                )
+                await session.exec(select(User).where(col(User.id).in_(user_ids)))
             ).all()
             # Cache user info serially to avoid concurrent database access issues
@@ -324,10 +333,7 @@ class UserCacheService:
                 logger.error(f"Error caching single user {user.id}: {e}")
     async def refresh_user_cache_on_score_submit(
-        self,
-        session: AsyncSession,
-        user_id: int,
-        mode: GameMode
+        self, session: AsyncSession, user_id: int, mode: GameMode
     ):
         """Refresh the user cache after a score submission"""
         try:
@@ -335,7 +341,7 @@ class UserCacheService:
             await self.invalidate_user_cache(user_id)
             await self.invalidate_v1_user_cache(user_id)
             await self.invalidate_user_scores_cache(user_id, mode)
             # Immediately reload the user info
             user = await session.get(User, user_id)
             if user and user.id != BANCHOBOT_ID:
@@ -351,7 +357,7 @@ class UserCacheService:
             v1_user_keys = await self.redis.keys("v1_user:*")
             all_keys = user_keys + v1_user_keys
             total_size = 0
             for key in all_keys[:100]:  # cap the number of keys checked
                 try:
                     size = await self.redis.memory_usage(key)
@@ -359,12 +365,22 @@ class UserCacheService:
                     total_size += size
                 except Exception:
                     continue
             return {
-                "cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
-                "cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
+                "cached_users": len(
+                    [
+                        k
+                        for k in user_keys
+                        if ":scores:" not in k and ":beatmapsets:" not in k
+                    ]
+                ),
+                "cached_v1_users": len(
+                    [k for k in v1_user_keys if ":scores:" not in k]
+                ),
                 "cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
-                "cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
+                "cached_user_beatmapsets": len(
+                    [k for k in user_keys if ":beatmapsets:" in k]
+                ),
                 "total_cached_entries": len(all_keys),
                 "estimated_total_size_mb": (
                     round(total_size / 1024 / 1024, 2) if total_size > 0 else 0