Add grade hot cache

咕谷酱
2025-08-21 23:35:25 +08:00
parent 7c193937d1
commit 822d7c6377
13 changed files with 973 additions and 47 deletions

View File

@@ -149,6 +149,14 @@ class Settings(BaseSettings):
ranking_cache_refresh_interval_minutes: int = 10  # ranking cache refresh interval (minutes)
ranking_cache_max_pages: int = 20  # maximum number of pages to cache
ranking_cache_top_countries: int = 20  # cache rankings for the top N countries
# User cache settings
enable_user_cache_preload: bool = True  # enable user cache preloading
user_cache_expire_seconds: int = 300  # user info cache TTL (seconds)
user_scores_cache_expire_seconds: int = 60  # user scores cache TTL (seconds)
user_beatmapsets_cache_expire_seconds: int = 600  # user beatmapsets cache TTL (seconds)
user_cache_max_preload_users: int = 200  # maximum number of users to preload
user_cache_concurrent_limit: int = 10  # concurrency limit when caching users
# Anti-cheat settings
suspicious_score_check: bool = True
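These knobs live on the pydantic `Settings` class, so they can be tuned per deployment without touching code. A minimal sketch, assuming `Settings` follows the usual `BaseSettings` behaviour of reading case-insensitive environment variables (the override mechanism is an assumption, not shown in this commit):

# Hypothetical environment overrides for the new user-cache settings.
import os

os.environ["USER_CACHE_EXPIRE_SECONDS"] = "120"      # shorter user-info TTL
os.environ["USER_CACHE_MAX_PRELOAD_USERS"] = "500"   # preload more users

from app.config import Settings  # class shown in the hunk above

settings = Settings()
assert settings.user_cache_expire_seconds == 120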

View File

@@ -7,8 +7,8 @@ from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, Column, DateTime, Text
from pydantic import BaseModel, field_validator, model_validator
from sqlalchemy import Boolean, JSON, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -73,16 +73,16 @@ class BeatmapsetBase(SQLModel):
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
creator: str = Field(index=True)
nsfw: bool = Field(default=False)
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
play_count: int = Field(index=True)
preview_url: str
source: str = Field(default="")
spotlight: bool = Field(default=False)
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
title: str = Field(index=True)
title_unicode: str = Field(index=True)
user_id: int = Field(index=True)
video: bool = Field(index=True)
video: bool = Field(sa_column=Column(Boolean, index=True))
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
@@ -102,13 +102,13 @@ class BeatmapsetBase(SQLModel):
# BeatmapsetExtended
bpm: float = Field(default=0.0)
can_be_hyped: bool = Field(default=False)
discussion_locked: bool = Field(default=False)
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field(
default=None, sa_column=Column(DateTime, index=True)
)
storyboard: bool = Field(default=False, index=True)
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text))
@@ -133,7 +133,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
hype_current: int = Field(default=0)
hype_required: int = Field(default=0)
availability_info: str | None = Field(default=None)
download_disabled: bool = Field(default=False)
download_disabled: bool = Field(default=False, sa_column=Column(Boolean))
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
@@ -205,6 +205,14 @@ class BeatmapsetResp(BeatmapsetBase):
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
@field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before')
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@model_validator(mode="after")
def fix_genre_language(self) -> Self:
if self.genre is None:
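To make the intent of the `mode='before'` validator above concrete, here is a hedged standalone sketch (the model and fields are made up): boolean columns stored as tinyint come back from the driver as 0/1, and the before-validator coerces them before normal bool validation runs.

# Hypothetical model mirroring the coercion pattern used in BeatmapsetResp.
from pydantic import BaseModel, field_validator

class FlagsResp(BaseModel):
    nsfw: bool
    video: bool

    @field_validator("nsfw", "video", mode="before")
    @classmethod
    def coerce_int_bool(cls, v):
        # Database drivers may hand back 0/1 integers for boolean columns.
        return bool(v) if isinstance(v, int) else v

print(FlagsResp(nsfw=0, video=1))  # nsfw=False video=True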

View File

@@ -0,0 +1,35 @@
"""
Database field type utilities.
Helpers for converting values between the database layer and Pydantic models.
"""
from typing import Any, Union
from pydantic import field_validator
from sqlalchemy import Boolean
def bool_field_validator(field_name: str):
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
@field_validator(field_name, mode="before")
@classmethod
def validate_bool_field(cls, v: Any) -> bool:
"""将整数 0/1 转换为布尔值"""
if isinstance(v, int):
return bool(v)
return v
return validate_bool_field
def create_bool_field(**kwargs):
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
from sqlmodel import Field, Column
# If sa_column is not given, build an explicit Boolean column
if 'sa_column' not in kwargs:
# Honour the index argument
index = kwargs.pop('index', False)
if index:
kwargs['sa_column'] = Column(Boolean, index=True)
else:
kwargs['sa_column'] = Column(Boolean)
return Field(**kwargs)
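For reference, a hedged usage sketch of the two helpers above; the model is hypothetical, and it assumes pydantic v2 picks up externally created validators assigned as class attributes (the documented reusable-validator pattern):

# Hypothetical model using create_bool_field and bool_field_validator.
from sqlmodel import SQLModel

class FlagsExample(SQLModel):
    hidden: bool = create_bool_field(default=False, index=True)

    # Factory-produced before-validator coercing 0/1 for the named field.
    validate_hidden = bool_field_validator("hidden")

print(FlagsExample(hidden=1))  # hidden=True after coercion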

View File

@@ -45,8 +45,9 @@ from .relationship import (
)
from .score_token import ScoreToken
from pydantic import field_validator
from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlalchemy.sql.elements import ColumnElement
@@ -79,13 +80,13 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
default=0, sa_column=Column(BigInteger)
) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool
passed: bool = Field(sa_column=Column(Boolean))
playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float = Field(default=0.0)
preserve: bool = Field(default=True)
preserve: bool = Field(default=True, sa_column=Column(Boolean))
rank: Rank
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
@@ -176,6 +177,14 @@ class ScoreResp(ScoreBase):
ranked: bool = False
current_user_attributes: CurrentUserAttributes | None = None
@field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before')
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
s = cls.model_validate(score.model_dump())

View File

@@ -28,9 +28,11 @@ def json_serializer(value):
engine = create_async_engine(
settings.database_url,
json_serializer=json_serializer,
pool_size=20,
max_overflow=20,
pool_size=30,  # larger connection pool
max_overflow=50,  # more overflow connections
pool_timeout=30.0,
pool_recycle=3600,  # recycle connections after 1 hour
pool_pre_ping=True,  # ping connections before use
)
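When tuning these numbers it can help to log the live pool state; a hedged sketch using the `engine` created above and the project's `logger` (`Pool.status()` is standard SQLAlchemy; the helper itself is not part of this commit):

# Log the current connection-pool state, e.g. while load-testing.
from app.log import logger  # logger import path as used elsewhere in this diff

def log_pool_status() -> None:
    # Typical output: "Pool size: 30  Connections in pool: 3  Current Overflow: -27 ..."
    logger.info(f"DB pool status: {engine.sync_engine.pool.status()}")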
# Redis connection

View File

@@ -1,12 +1,16 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Literal
from app.config import settings
from app.database.lazer_user import User
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies.database import Database
from app.dependencies.database import Database, get_redis
from app.log import logger
from app.models.score import GameMode
from app.service.user_cache_service import get_user_cache_service
from .router import AllStrModel, router
@@ -38,10 +42,21 @@ class V1User(AllStrModel):
pp_country_rank: int
events: list[dict]
@classmethod
def _get_cache_key(cls, user_id: int, ruleset: GameMode | None = None) -> str:
"""生成 V1 用户缓存键"""
if ruleset:
return f"v1_user:{user_id}:ruleset:{ruleset}"
return f"v1_user:{user_id}"
@classmethod
async def from_db(
cls, session: Database, db_user: User, ruleset: GameMode | None = None
) -> "V1User":
# Make sure the user id is not None
if db_user.id is None:
raise ValueError("User ID cannot be None")
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics:
@@ -101,24 +116,55 @@ async def get_user(
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# Determine how the user is being looked up
is_id_query = type == "id" or user.isdigit()
# Resolve the ruleset
ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None
# For id queries, try the cache first
cached_v1_user = None
user_id_for_cache = None
if is_id_query:
try:
user_id_for_cache = int(user)
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
if cached_v1_user:
return [V1User(**cached_v1_user)]
except (ValueError, TypeError):
pass  # not a valid user id; fall through to the database query
# Query the user from the database
db_user = (
await session.exec(
select(User).where(
User.id == user
if type == "id" or user.isdigit()
else User.username == user,
User.id == user if is_id_query else User.username == user,
)
)
).first()
if not db_user:
return []
try:
return [
await V1User.from_db(
session,
db_user,
GameMode.from_int_extra(ruleset_id) if ruleset_id else None,
# Build the user payload
v1_user = await V1User.from_db(session, db_user, ruleset)
# Cache the result asynchronously (when we have a user id)
if db_user.id is not None:
user_data = v1_user.model_dump()
asyncio.create_task(
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
)
]
return [v1_user]
except KeyError:
raise HTTPException(400, "Invalid request")
except ValueError as e:
logger.error(f"Error processing V1 user data: {e}")
raise HTTPException(500, "Internal server error")

app/router/v2/cache.py (new file, 144 lines)
View File

@@ -0,0 +1,144 @@
"""
Cache management and monitoring endpoints.
Provides cache statistics, invalidation, and warmup functionality.
"""
from __future__ import annotations
from app.dependencies.database import get_redis
from app.dependencies.user import get_current_user
from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import Depends, HTTPException, Security
from pydantic import BaseModel
from redis.asyncio import Redis
class CacheStatsResponse(BaseModel):
user_cache: dict
redis_info: dict
@router.get(
"/cache/stats",
response_model=CacheStatsResponse,
name="获取缓存统计信息",
description="获取用户缓存和Redis的统计信息需要管理员权限。",
tags=["缓存管理"],
)
async def get_cache_stats(
redis: Redis = Depends(get_redis),
# current_user: User = Security(get_current_user, scopes=["admin"]),  # commented out for now; enable if needed
):
try:
cache_service = get_user_cache_service(redis)
user_cache_stats = await cache_service.get_cache_stats()
# Fetch basic Redis info
redis_info = await redis.info()
redis_stats = {
"connected_clients": redis_info.get("connected_clients", 0),
"used_memory_human": redis_info.get("used_memory_human", "0B"),
"used_memory_peak_human": redis_info.get("used_memory_peak_human", "0B"),
"total_commands_processed": redis_info.get("total_commands_processed", 0),
"keyspace_hits": redis_info.get("keyspace_hits", 0),
"keyspace_misses": redis_info.get("keyspace_misses", 0),
"evicted_keys": redis_info.get("evicted_keys", 0),
"expired_keys": redis_info.get("expired_keys", 0),
}
# Compute the cache hit rate
hits = redis_stats["keyspace_hits"]
misses = redis_stats["keyspace_misses"]
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
return CacheStatsResponse(
user_cache=user_cache_stats,
redis_info=redis_stats
)
except Exception as e:
raise HTTPException(500, f"Failed to get cache stats: {str(e)}")
@router.post(
"/cache/invalidate/{user_id}",
name="清除指定用户缓存",
description="清除指定用户的所有缓存数据,需要管理员权限。",
tags=["缓存管理"],
)
async def invalidate_user_cache(
user_id: int,
redis: Redis = Depends(get_redis),
# current_user: User = Security(get_current_user, scopes=["admin"]),  # commented out for now
):
try:
cache_service = get_user_cache_service(redis)
await cache_service.invalidate_user_cache(user_id)
await cache_service.invalidate_v1_user_cache(user_id)
return {"message": f"Cache invalidated for user {user_id}"}
except Exception as e:
raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
@router.post(
"/cache/clear",
name="清除所有用户缓存",
description="清除所有用户相关的缓存数据,需要管理员权限。谨慎使用!",
tags=["缓存管理"],
)
async def clear_all_user_cache(
redis: Redis = Depends(get_redis),
# current_user: User = Security(get_current_user, scopes=["admin"]),  # commented out for now
):
try:
# Collect all user-related cache keys
user_keys = await redis.keys("user:*")
v1_user_keys = await redis.keys("v1_user:*")
all_keys = user_keys + v1_user_keys
if all_keys:
await redis.delete(*all_keys)
return {"message": f"Cleared {len(all_keys)} cache entries"}
else:
return {"message": "No cache entries found"}
except Exception as e:
raise HTTPException(500, f"Failed to clear cache: {str(e)}")
class CacheWarmupRequest(BaseModel):
user_ids: list[int] | None = None
limit: int = 100
@router.post(
"/cache/warmup",
name="缓存预热",
description="对指定用户或活跃用户进行缓存预热,需要管理员权限。",
tags=["缓存管理"],
)
async def warmup_cache(
request: CacheWarmupRequest,
redis: Redis = Depends(get_redis),
# current_user: User = Security(get_current_user, scopes=["admin"]),  # commented out for now
):
try:
cache_service = get_user_cache_service(redis)
if request.user_ids:
# Warm up the specified users
from app.dependencies.database import with_db
async with with_db() as session:
await cache_service.preload_user_cache(session, request.user_ids)
return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
else:
# Warm up active users
from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
await schedule_user_cache_preload_task()
return {"message": f"Warmed up cache for top {request.limit} active users"}
except Exception as e:
raise HTTPException(500, f"Failed to warmup cache: {str(e)}")

View File

@@ -47,6 +47,7 @@ from app.models.score import (
Rank,
SoloScoreSubmissionInfo,
)
from app.service.user_cache_service import get_user_cache_service
from app.storage.base import StorageService
from app.storage.local import LocalStorageService
@@ -182,6 +183,17 @@ async def submit_score(
}
db.add(rank_event)
await db.commit()
# Refresh the user cache after a score submission
try:
user_cache_service = get_user_cache_service(redis)
if current_user.id is not None:
await user_cache_service.refresh_user_cache_on_score_submit(
db, current_user.id, score.gamemode
)
except Exception as e:
logger.error(f"Failed to refresh user cache after score submit: {e}")
background_task.add_task(process_user_achievement, resp.id)
return resp

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Literal
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import (
BeatmapPlaycounts,
@@ -15,10 +17,11 @@ from app.database.events import EventResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore
from app.database.score import Score, ScoreResp
from app.dependencies.database import Database
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from app.service.user_cache_service import get_user_cache_service
from .router import router
@@ -51,23 +54,55 @@ async def get_users(
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
if user_ids:
searched_users = (
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
).all()
# Try the cache first
cached_users = []
uncached_user_ids = []
for user_id in user_ids[:50]:  # cap at 50 users
cached_user = await cache_service.get_user_from_cache(user_id)
if cached_user:
cached_users.append(cached_user)
else:
uncached_user_ids.append(user_id)
# Query the users that were not cached
if uncached_user_ids:
searched_users = (
await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
).all()
# Cache the fetched users and include them in the response
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
cached_users.append(user_resp)
# Cache asynchronously without blocking the response
asyncio.create_task(cache_service.cache_user(user_resp))
return BatchUserResponse(users=cached_users)
else:
searched_users = (await session.exec(select(User).limit(50))).all()
return BatchUserResponse(
users=[
await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
for searched_user in searched_users
if searched_user.id != BANCHOBOT_ID
]
)
users = []
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
users.append(user_resp)
# Cache asynchronously
asyncio.create_task(cache_service.cache_user(user_resp))
return BatchUserResponse(users=users)
@router.get(
@@ -83,6 +118,16 @@ async def get_user_info_ruleset(
ruleset: GameMode | None = Path(description="指定 ruleset"),
# current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# If the id is numeric, try the cache first
if user_id.isdigit():
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset)
if cached_user:
return cached_user
searched_user = (
await session.exec(
select(User).where(
@@ -94,12 +139,18 @@ async def get_user_info_ruleset(
).first()
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
return await UserResp.from_db(
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
ruleset=ruleset,
)
# Cache the result asynchronously
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
return user_resp
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
@@ -115,6 +166,16 @@ async def get_user_info(
user_id: str = Path(description="用户 ID 或用户名"),
# current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# If the id is numeric, try the cache first
if user_id.isdigit():
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
return cached_user
searched_user = (
await session.exec(
select(User).where(
@@ -126,11 +187,17 @@ async def get_user_info(
).first()
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
return await UserResp.from_db(
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
# Cache the result asynchronously
asyncio.create_task(cache_service.cache_user(user_resp))
return user_resp
@router.get(
@@ -148,6 +215,20 @@ async def get_user_beatmapsets(
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# Try the cache first
cached_result = await cache_service.get_user_beatmapsets_from_cache(
user_id, type.value, limit, offset
)
if cached_result is not None:
# Rebuild the response objects according to the requested type
if type == BeatmapsetType.MOST_PLAYED:
return [BeatmapPlaycountsResp(**item) for item in cached_result]
else:
return [BeatmapsetResp(**item) for item in cached_result]
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
@@ -190,6 +271,11 @@ async def get_user_beatmapsets(
else:
raise HTTPException(400, detail="Invalid beatmapset type")
# Cache the result asynchronously
asyncio.create_task(
cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
)
return resp
@@ -218,6 +304,17 @@ async def get_user_scores(
offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# Try the cache first; "recent" scores use a shorter TTL
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache(
user_id, type, mode, limit, offset
)
if cached_scores is not None:
return cached_scores
db_user = await session.get(User, user_id)
if not db_user or db_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
@@ -253,13 +350,23 @@ async def get_user_scores(
).all()
if not scores:
return []
return [
score_responses = [
await ScoreResp.from_db(
session,
score,
)
for score in scores
]
# Cache the result asynchronously
asyncio.create_task(
cache_service.cache_user_scores(
user_id, type, score_responses, mode, limit, offset, cache_expire
)
)
return score_responses
@router.get(

View File

@@ -6,6 +6,11 @@ from app.config import settings
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.log import logger
from app.scheduler.user_cache_scheduler import (
schedule_user_cache_cleanup_task,
schedule_user_cache_preload_task,
schedule_user_cache_warmup_task,
)
class CacheScheduler:
@@ -42,17 +47,26 @@ class CacheScheduler:
# Run a ranking cache refresh once at startup
await self._refresh_ranking_cache()
# Warm up the user cache once at startup
await self._warmup_user_cache()
beatmap_cache_counter = 0
ranking_cache_counter = 0
user_cache_counter = 0
user_cleanup_counter = 0
# Interval settings (the ranking interval comes from the config)
check_interval = 5 * 60  # 5-minute check interval
beatmap_cache_interval = 30 * 60  # 30-minute beatmap cache interval
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60  # from config
user_cache_interval = 15 * 60  # 15-minute user cache preload interval
user_cleanup_interval = 60 * 60  # 60-minute user cache cleanup interval
beatmap_cache_cycles = beatmap_cache_interval // check_interval
ranking_cache_cycles = ranking_cache_interval // check_interval
user_cache_cycles = user_cache_interval // check_interval
user_cleanup_cycles = user_cleanup_interval // check_interval
while self.running:
try:
@@ -64,6 +78,8 @@ class CacheScheduler:
beatmap_cache_counter += 1
ranking_cache_counter += 1
user_cache_counter += 1
user_cleanup_counter += 1
# Beatmap cache warmup
if beatmap_cache_counter >= beatmap_cache_cycles:
@@ -74,6 +90,16 @@ class CacheScheduler:
if ranking_cache_counter >= ranking_cache_cycles:
await self._refresh_ranking_cache()
ranking_cache_counter = 0
# User cache preload
if user_cache_counter >= user_cache_cycles:
await self._preload_user_cache()
user_cache_counter = 0
# User cache cleanup
if user_cleanup_counter >= user_cleanup_cycles:
await self._cleanup_user_cache()
user_cleanup_counter = 0
except asyncio.CancelledError:
break
@@ -110,16 +136,37 @@ class CacheScheduler:
schedule_ranking_refresh_task,
)
# Get a database session
async for session in get_db():
# Use an independent database session
from app.dependencies.database import with_db
async with with_db() as session:
await schedule_ranking_refresh_task(session, redis)
break  # only one session is needed
logger.info("Ranking cache refresh completed successfully")
except Exception as e:
logger.error(f"Ranking cache refresh failed: {e}")
async def _warmup_user_cache(self):
"""用户缓存预热"""
try:
await schedule_user_cache_warmup_task()
except Exception as e:
logger.error(f"User cache warmup failed: {e}")
async def _preload_user_cache(self):
"""用户缓存预加载"""
try:
await schedule_user_cache_preload_task()
except Exception as e:
logger.error(f"User cache preload failed: {e}")
async def _cleanup_user_cache(self):
"""用户缓存清理"""
try:
await schedule_user_cache_cleanup_task()
except Exception as e:
logger.error(f"User cache cleanup failed: {e}")
# Beatmap cache scheduler, kept for backward compatibility
class BeatmapsetCacheScheduler(CacheScheduler):

View File

@@ -0,0 +1,120 @@
"""
User cache warmup task scheduler.
"""
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from app.config import settings
from app.database import User
from app.database.score import Score
from app.dependencies.database import get_db, get_redis
from app.log import logger
from app.service.user_cache_service import get_user_cache_service
from sqlmodel import col, func, select
async def schedule_user_cache_preload_task():
"""定时用户缓存预加载任务"""
# 默认启用用户缓存预加载,除非明确禁用
enable_user_cache_preload = getattr(settings, "enable_user_cache_preload", True)
if not enable_user_cache_preload:
return
try:
logger.info("Starting user cache preload task...")
redis = get_redis()
cache_service = get_user_cache_service(redis)
# Use an independent database session
from app.dependencies.database import with_db
async with with_db() as session:
# Get users active in the last 24 hours (i.e. users who submitted scores)
recent_time = datetime.now(UTC) - timedelta(hours=24)
active_user_ids = (
await session.exec(
select(Score.user_id, func.count().label("score_count"))
.where(col(Score.ended_at) >= recent_time)
.group_by(col(Score.user_id))
.order_by(col("score_count").desc())
.limit(settings.user_cache_max_preload_users) # 使用配置中的限制
)
).all()
if active_user_ids:
user_ids = [row[0] for row in active_user_ids]
await cache_service.preload_user_cache(session, user_ids)
logger.info(f"Preloaded cache for {len(user_ids)} active users")
else:
logger.info("No active users found for cache preload")
logger.info("User cache preload task completed successfully")
except Exception as e:
logger.error(f"User cache preload task failed: {e}")
async def schedule_user_cache_warmup_task():
"""定时用户缓存预热任务 - 预加载排行榜前100用户"""
try:
logger.info("Starting user cache warmup task...")
redis = get_redis()
cache_service = get_user_cache_service(redis)
# Use an independent database session
from app.dependencies.database import with_db
async with with_db() as session:
# Get the top 100 users on the global leaderboard
from app.database.statistics import UserStatistics
from app.models.score import GameMode
for mode in GameMode:
try:
top_users = (
await session.exec(
select(UserStatistics.user_id)
.where(UserStatistics.mode == mode)
.order_by(col(UserStatistics.pp).desc())
.limit(100)
)
).all()
if top_users:
user_ids = list(top_users)
await cache_service.preload_user_cache(session, user_ids)
logger.info(f"Warmed cache for top 100 users in {mode}")
# Brief delay to avoid overloading the database
await asyncio.sleep(1)
except Exception as e:
logger.error(f"Failed to warm cache for {mode}: {e}")
continue
logger.info("User cache warmup task completed successfully")
except Exception as e:
logger.error(f"User cache warmup task failed: {e}")
async def schedule_user_cache_cleanup_task():
"""定时用户缓存清理任务"""
try:
logger.info("Starting user cache cleanup task...")
redis = get_redis()
# Redis evicts expired user cache entries via TTL; this task mainly logs statistics
cache_service = get_user_cache_service(redis)
stats = await cache_service.get_cache_stats()
logger.info(f"User cache stats: {stats}")
logger.info("User cache cleanup task completed successfully")
except Exception as e:
logger.error(f"User cache cleanup task failed: {e}")

View File

@@ -421,7 +421,7 @@ class RankingCacheService:
select(User.country_code, func.count().label("user_count"))
.where(col(User.is_active).is_(True))
.group_by(User.country_code)
.order_by(col("user_count").desc())
.order_by(func.count().desc())
.limit(settings.ranking_cache_top_countries)
)
).all()

View File

@@ -0,0 +1,388 @@
"""
User cache service.
Caches user information, providing a hot cache and on-demand refresh.
"""
from __future__ import annotations
import asyncio
import json
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Any, Literal
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.database.pp_best_score import PPBestScore
from app.database.score import Score, ScoreResp
from app.log import logger
from app.models.score import GameMode
from redis.asyncio import Redis
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data: Any) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
class UserCacheService:
"""用户缓存服务"""
def __init__(self, redis: Redis):
self.redis = redis
self._refreshing = False
self._background_tasks: set = set()
def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
"""生成 V1 用户缓存键"""
if ruleset:
return f"v1_user:{user_id}:ruleset:{ruleset}"
return f"v1_user:{user_id}"
async def get_v1_user_from_cache(
self,
user_id: int,
ruleset: GameMode | None = None
) -> dict | None:
"""从缓存获取 V1 用户信息"""
try:
cache_key = self._get_v1_user_cache_key(user_id, ruleset)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"V1 User cache hit for user {user_id}")
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error getting V1 user from cache: {e}")
return None
async def cache_v1_user(
self,
user_data: dict,
user_id: int,
ruleset: GameMode | None = None,
expire_seconds: int | None = None
):
"""缓存 V1 用户信息"""
try:
if expire_seconds is None:
expire_seconds = settings.user_cache_expire_seconds
cache_key = self._get_v1_user_cache_key(user_id, ruleset)
cached_data = safe_json_dumps(user_data)
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached V1 user {user_id} for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching V1 user: {e}")
async def invalidate_v1_user_cache(self, user_id: int):
"""使 V1 用户缓存失效"""
try:
# 删除 V1 用户信息缓存
pattern = f"v1_user:{user_id}*"
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
except Exception as e:
logger.error(f"Error invalidating V1 user cache: {e}")
def _get_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
"""生成用户缓存键"""
if ruleset:
return f"user:{user_id}:ruleset:{ruleset}"
return f"user:{user_id}"
def _get_user_scores_cache_key(
self,
user_id: int,
score_type: str,
mode: GameMode | None = None,
limit: int = 100,
offset: int = 0
) -> str:
"""生成用户成绩缓存键"""
mode_part = f":{mode}" if mode else ""
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
def _get_user_beatmapsets_cache_key(
self,
user_id: int,
beatmapset_type: str,
limit: int = 100,
offset: int = 0
) -> str:
"""生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache(
self,
user_id: int,
ruleset: GameMode | None = None
) -> UserResp | None:
"""从缓存获取用户信息"""
try:
cache_key = self._get_user_cache_key(user_id, ruleset)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"User cache hit for user {user_id}")
data = json.loads(cached_data)
return UserResp(**data)
return None
except Exception as e:
logger.error(f"Error getting user from cache: {e}")
return None
async def cache_user(
self,
user_resp: UserResp,
ruleset: GameMode | None = None,
expire_seconds: int | None = None
):
"""缓存用户信息"""
try:
if expire_seconds is None:
expire_seconds = settings.user_cache_expire_seconds
if user_resp.id is None:
logger.warning("Cannot cache user with None id")
return
cache_key = self._get_user_cache_key(user_resp.id, ruleset)
cached_data = user_resp.model_dump_json()
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user: {e}")
async def get_user_scores_from_cache(
self,
user_id: int,
score_type: str,
mode: GameMode | None = None,
limit: int = 100,
offset: int = 0
) -> list[ScoreResp] | None:
"""从缓存获取用户成绩"""
try:
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
data = json.loads(cached_data)
return [ScoreResp(**score_data) for score_data in data]
return None
except Exception as e:
logger.error(f"Error getting user scores from cache: {e}")
return None
async def cache_user_scores(
self,
user_id: int,
score_type: str,
scores: list[ScoreResp],
mode: GameMode | None = None,
limit: int = 100,
offset: int = 0,
expire_seconds: int | None = None
):
"""缓存用户成绩"""
try:
if expire_seconds is None:
expire_seconds = settings.user_scores_cache_expire_seconds
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
# Use model_dump_json() instead of model_dump() + json.dumps()
scores_json_list = [score.model_dump_json() for score in scores]
cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user scores: {e}")
async def get_user_beatmapsets_from_cache(
self,
user_id: int,
beatmapset_type: str,
limit: int = 100,
offset: int = 0
) -> list[Any] | None:
"""从缓存获取用户谱面集"""
try:
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error getting user beatmapsets from cache: {e}")
return None
async def cache_user_beatmapsets(
self,
user_id: int,
beatmapset_type: str,
beatmapsets: list[Any],
limit: int = 100,
offset: int = 0,
expire_seconds: int | None = None
):
"""缓存用户谱面集"""
try:
if expire_seconds is None:
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
# Use model_dump_json() for objects that provide it, otherwise fall back to safe_json_dumps
serialized_beatmapsets = []
for bms in beatmapsets:
if hasattr(bms, 'model_dump_json'):
serialized_beatmapsets.append(bms.model_dump_json())
else:
serialized_beatmapsets.append(safe_json_dumps(bms))
cached_data = f"[{','.join(serialized_beatmapsets)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user beatmapsets: {e}")
async def invalidate_user_cache(self, user_id: int):
"""使用户缓存失效"""
try:
# 删除用户信息缓存
pattern = f"user:{user_id}*"
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
logger.info(f"Invalidated {len(keys)} cache entries for user {user_id}")
except Exception as e:
logger.error(f"Error invalidating user cache: {e}")
async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
"""使用户成绩缓存失效"""
try:
# 删除用户成绩相关缓存
mode_pattern = f":{mode}" if mode else "*"
pattern = f"user:{user_id}:scores:*{mode_pattern}*"
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
except Exception as e:
logger.error(f"Error invalidating user scores cache: {e}")
async def preload_user_cache(self, session: AsyncSession, user_ids: list[int]):
"""预加载用户缓存"""
if self._refreshing:
return
self._refreshing = True
try:
logger.info(f"Preloading cache for {len(user_ids)} users")
# Fetch the users in one query
users = (
await session.exec(
select(User).where(col(User.id).in_(user_ids))
)
).all()
# Cache users sequentially to avoid concurrent database access issues
cached_count = 0
for user in users:
if user.id != BANCHOBOT_ID:
try:
await self._cache_single_user(user, session)
cached_count += 1
except Exception as e:
logger.error(f"Failed to cache user {user.id}: {e}")
logger.info(f"Preloaded cache for {cached_count} users")
except Exception as e:
logger.error(f"Error preloading user cache: {e}")
finally:
self._refreshing = False
async def _cache_single_user(self, user: User, session: AsyncSession):
"""缓存单个用户"""
try:
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
await self.cache_user(user_resp)
except Exception as e:
logger.error(f"Error caching single user {user.id}: {e}")
async def refresh_user_cache_on_score_submit(
self,
session: AsyncSession,
user_id: int,
mode: GameMode
):
"""成绩提交后刷新用户缓存"""
try:
# Invalidate the related caches (both v1 and v2)
await self.invalidate_user_cache(user_id)
await self.invalidate_v1_user_cache(user_id)
await self.invalidate_user_scores_cache(user_id, mode)
# Reload the user info immediately
user = await session.get(User, user_id)
if user and user.id != BANCHOBOT_ID:
await self._cache_single_user(user, session)
logger.info(f"Refreshed cache for user {user_id} after score submit")
except Exception as e:
logger.error(f"Error refreshing user cache on score submit: {e}")
async def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
try:
user_keys = await self.redis.keys("user:*")
v1_user_keys = await self.redis.keys("v1_user:*")
all_keys = user_keys + v1_user_keys
total_size = 0
for key in all_keys[:100]:  # limit how many keys are inspected
try:
size = await self.redis.memory_usage(key)
if size:
total_size += size
except Exception:
continue
return {
"cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
"cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
"cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
"total_cached_entries": len(all_keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"refreshing": self._refreshing,
}
except Exception as e:
logger.error(f"Error getting user cache stats: {e}")
return {"error": str(e)}
# Global cache service instance
_user_cache_service: UserCacheService | None = None
def get_user_cache_service(redis: Redis) -> UserCacheService:
"""获取用户缓存服务实例"""
global _user_cache_service
if _user_cache_service is None:
_user_cache_service = UserCacheService(redis)
return _user_cache_service
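Finally, a hedged end-to-end sketch of using the service outside the routers; the user id is made up, and it assumes `get_redis()` returns the app's shared async client as it does in the routers above:

# Hypothetical standalone usage of UserCacheService.
import asyncio

from app.dependencies.database import get_redis
from app.service.user_cache_service import get_user_cache_service

async def main() -> None:
    cache_service = get_user_cache_service(get_redis())
    cached = await cache_service.get_user_from_cache(12345)
    if cached is None:
        print("cache miss: the routers would now hit the database and write back")
    else:
        print(f"cache hit for user {cached.id}")
    print(await cache_service.get_cache_stats())

asyncio.run(main())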