From 822d7c6377c2fe84b4cc5a3bd95fac90361f4949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com> Date: Thu, 21 Aug 2025 23:35:25 +0800 Subject: [PATCH] Add grade hot cache --- app/config.py | 8 + app/database/beatmapset.py | 26 +- app/database/field_utils.py | 35 +++ app/database/score.py | 17 +- app/dependencies/database.py | 6 +- app/router/v1/user.py | 66 ++++- app/router/v2/cache.py | 144 ++++++++++ app/router/v2/score.py | 12 + app/router/v2/user.py | 143 ++++++++-- app/scheduler/cache_scheduler.py | 53 +++- app/scheduler/user_cache_scheduler.py | 120 ++++++++ app/service/ranking_cache_service.py | 2 +- app/service/user_cache_service.py | 388 ++++++++++++++++++++++++++ 13 files changed, 973 insertions(+), 47 deletions(-) create mode 100644 app/database/field_utils.py create mode 100644 app/router/v2/cache.py create mode 100644 app/scheduler/user_cache_scheduler.py create mode 100644 app/service/user_cache_service.py diff --git a/app/config.py b/app/config.py index 1c679f5..96ee2df 100644 --- a/app/config.py +++ b/app/config.py @@ -149,6 +149,14 @@ class Settings(BaseSettings): ranking_cache_refresh_interval_minutes: int = 10 # 排行榜缓存刷新间隔(分钟) ranking_cache_max_pages: int = 20 # 最多缓存的页数 ranking_cache_top_countries: int = 20 # 缓存前N个国家的排行榜 + + # 用户缓存设置 + enable_user_cache_preload: bool = True # 启用用户缓存预加载 + user_cache_expire_seconds: int = 300 # 用户信息缓存过期时间(秒) + user_scores_cache_expire_seconds: int = 60 # 用户成绩缓存过期时间(秒) + user_beatmapsets_cache_expire_seconds: int = 600 # 用户谱面集缓存过期时间(秒) + user_cache_max_preload_users: int = 200 # 最多预加载的用户数量 + user_cache_concurrent_limit: int = 10 # 并发缓存用户的限制 # 反作弊设置 suspicious_score_check: bool = True diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 2c4dfd8..447da5b 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -7,8 +7,8 @@ from app.models.score import GameMode from .lazer_user import BASE_INCLUDES, User, UserResp -from pydantic import BaseModel, model_validator -from sqlalchemy import JSON, Column, DateTime, Text +from pydantic import BaseModel, field_validator, model_validator +from sqlalchemy import Boolean, JSON, Column, DateTime, Text from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -73,16 +73,16 @@ class BeatmapsetBase(SQLModel): artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) creator: str = Field(index=True) - nsfw: bool = Field(default=False) + nsfw: bool = Field(default=False, sa_column=Column(Boolean)) play_count: int = Field(index=True) preview_url: str source: str = Field(default="") - spotlight: bool = Field(default=False) + spotlight: bool = Field(default=False, sa_column=Column(Boolean)) title: str = Field(index=True) title_unicode: str = Field(index=True) user_id: int = Field(index=True) - video: bool = Field(index=True) + video: bool = Field(sa_column=Column(Boolean, index=True)) # optional # converts: list[Beatmap] = Relationship(back_populates="beatmapset") @@ -102,13 +102,13 @@ class BeatmapsetBase(SQLModel): # BeatmapsetExtended bpm: float = Field(default=0.0) - can_be_hyped: bool = Field(default=False) - discussion_locked: bool = Field(default=False) + can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean)) + discussion_locked: bool = Field(default=False, sa_column=Column(Boolean)) last_updated: datetime = 
Field(sa_column=Column(DateTime, index=True)) ranked_date: datetime | None = Field( default=None, sa_column=Column(DateTime, index=True) ) - storyboard: bool = Field(default=False, index=True) + storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True)) submitted_date: datetime = Field(sa_column=Column(DateTime, index=True)) tags: str = Field(default="", sa_column=Column(Text)) @@ -133,7 +133,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): hype_current: int = Field(default=0) hype_required: int = Field(default=0) availability_info: str | None = Field(default=None) - download_disabled: bool = Field(default=False) + download_disabled: bool = Field(default=False, sa_column=Column(Boolean)) favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod @@ -205,6 +205,14 @@ class BeatmapsetResp(BeatmapsetBase): favourite_count: int = 0 recent_favourites: list[UserResp] = Field(default_factory=list) + @field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before') + @classmethod + def validate_bool_fields(cls, v): + """将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" + if isinstance(v, int): + return bool(v) + return v + @model_validator(mode="after") def fix_genre_language(self) -> Self: if self.genre is None: diff --git a/app/database/field_utils.py b/app/database/field_utils.py new file mode 100644 index 0000000..53c1011 --- /dev/null +++ b/app/database/field_utils.py @@ -0,0 +1,35 @@ +""" +数据库字段类型工具 +提供处理数据库和 Pydantic 之间类型转换的工具 +""" +from typing import Any, Union +from pydantic import field_validator +from sqlalchemy import Boolean + + +def bool_field_validator(field_name: str): + """为特定布尔字段创建验证器,处理数据库中的 0/1 整数""" + @field_validator(field_name, mode="before") + @classmethod + def validate_bool_field(cls, v: Any) -> bool: + """将整数 0/1 转换为布尔值""" + if isinstance(v, int): + return bool(v) + return v + return validate_bool_field + + +def create_bool_field(**kwargs): + """创建一个带有正确 SQLAlchemy 列定义的布尔字段""" + from sqlmodel import Field, Column + + # 如果没有指定 sa_column,则使用 Boolean 类型 + if 'sa_column' not in kwargs: + # 处理 index 参数 + index = kwargs.pop('index', False) + if index: + kwargs['sa_column'] = Column(Boolean, index=True) + else: + kwargs['sa_column'] = Column(Boolean) + + return Field(**kwargs) diff --git a/app/database/score.py b/app/database/score.py index f1d3167..a92c6ad 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -45,8 +45,9 @@ from .relationship import ( ) from .score_token import ScoreToken +from pydantic import field_validator from redis.asyncio import Redis -from sqlalchemy import Column, ColumnExpressionArgument, DateTime +from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import aliased from sqlalchemy.sql.elements import ColumnElement @@ -79,13 +80,13 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): default=0, sa_column=Column(BigInteger) ) # solo_score ended_at: datetime = Field(sa_column=Column(DateTime)) - has_replay: bool + has_replay: bool = Field(sa_column=Column(Boolean)) max_combo: int mods: list[APIMod] = Field(sa_column=Column(JSON)) - passed: bool + passed: bool = Field(sa_column=Column(Boolean)) playlist_item_id: int | None = Field(default=None) # multiplayer pp: float = Field(default=0.0) - preserve: bool = Field(default=True) + preserve: bool = Field(default=True, sa_column=Column(Boolean)) rank: Rank 
room_id: int | None = Field(default=None) # multiplayer started_at: datetime = Field(sa_column=Column(DateTime)) @@ -176,6 +177,14 @@ class ScoreResp(ScoreBase): ranked: bool = False current_user_attributes: CurrentUserAttributes | None = None + @field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before') + @classmethod + def validate_bool_fields(cls, v): + """将整数 0/1 转换为布尔值,处理数据库中的布尔字段""" + if isinstance(v, int): + return bool(v) + return v + @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 83e5876..f345537 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -28,9 +28,11 @@ def json_serializer(value): engine = create_async_engine( settings.database_url, json_serializer=json_serializer, - pool_size=20, - max_overflow=20, + pool_size=30, # 增加连接池大小 + max_overflow=50, # 增加最大溢出连接数 pool_timeout=30.0, + pool_recycle=3600, # 1小时回收连接 + pool_pre_ping=True, # 启用连接预检查 ) # Redis 连接 diff --git a/app/router/v1/user.py b/app/router/v1/user.py index 19ada39..d8938ed 100644 --- a/app/router/v1/user.py +++ b/app/router/v1/user.py @@ -1,12 +1,16 @@ from __future__ import annotations +import asyncio from datetime import datetime from typing import Literal +from app.config import settings from app.database.lazer_user import User from app.database.statistics import UserStatistics, UserStatisticsResp -from app.dependencies.database import Database +from app.dependencies.database import Database, get_redis +from app.log import logger from app.models.score import GameMode +from app.service.user_cache_service import get_user_cache_service from .router import AllStrModel, router @@ -38,10 +42,21 @@ class V1User(AllStrModel): pp_country_rank: int events: list[dict] + @classmethod + def _get_cache_key(cls, user_id: int, ruleset: GameMode | None = None) -> str: + """生成 V1 用户缓存键""" + if ruleset: + return f"v1_user:{user_id}:ruleset:{ruleset}" + return f"v1_user:{user_id}" + @classmethod async def from_db( cls, session: Database, db_user: User, ruleset: GameMode | None = None ) -> "V1User": + # 确保 user_id 不为 None + if db_user.id is None: + raise ValueError("User ID cannot be None") + ruleset = ruleset or db_user.playmode current_statistics: UserStatistics | None = None for i in await db_user.awaitable_attrs.statistics: @@ -101,24 +116,55 @@ async def get_user( default=1, ge=1, le=31, description="从现在起所有事件的最大天数" ), ): + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 确定查询方式和用户ID + is_id_query = type == "id" or user.isdigit() + + # 解析 ruleset + ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None + + # 如果是 ID 查询,先尝试从缓存获取 + cached_v1_user = None + user_id_for_cache = None + + if is_id_query: + try: + user_id_for_cache = int(user) + cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset) + if cached_v1_user: + return [V1User(**cached_v1_user)] + except (ValueError, TypeError): + pass # 不是有效的用户ID,继续数据库查询 + + # 从数据库查询用户 db_user = ( await session.exec( select(User).where( - User.id == user - if type == "id" or user.isdigit() - else User.username == user, + User.id == user if is_id_query else User.username == user, ) ) ).first() + if not db_user: return [] + try: - return [ - await V1User.from_db( - session, - db_user, - GameMode.from_int_extra(ruleset_id) if ruleset_id else None, + # 生成用户数据 + v1_user 
= await V1User.from_db(session, db_user, ruleset) + + # 异步缓存结果(如果有用户ID) + if db_user.id is not None: + user_data = v1_user.model_dump() + asyncio.create_task( + cache_service.cache_v1_user(user_data, db_user.id, ruleset) ) - ] + + return [v1_user] + except KeyError: raise HTTPException(400, "Invalid request") + except ValueError as e: + logger.error(f"Error processing V1 user data: {e}") + raise HTTPException(500, "Internal server error") diff --git a/app/router/v2/cache.py b/app/router/v2/cache.py new file mode 100644 index 0000000..7dd970c --- /dev/null +++ b/app/router/v2/cache.py @@ -0,0 +1,144 @@ +""" +缓存管理和监控接口 +提供缓存统计、清理和预热功能 +""" +from __future__ import annotations + +from app.dependencies.database import get_redis +from app.dependencies.user import get_current_user +from app.service.user_cache_service import get_user_cache_service + +from .router import router + +from fastapi import Depends, HTTPException, Security +from pydantic import BaseModel +from redis.asyncio import Redis + + +class CacheStatsResponse(BaseModel): + user_cache: dict + redis_info: dict + + +@router.get( + "/cache/stats", + response_model=CacheStatsResponse, + name="获取缓存统计信息", + description="获取用户缓存和Redis的统计信息,需要管理员权限。", + tags=["缓存管理"], +) +async def get_cache_stats( + redis: Redis = Depends(get_redis), + # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用 +): + try: + cache_service = get_user_cache_service(redis) + user_cache_stats = await cache_service.get_cache_stats() + + # 获取 Redis 基本信息 + redis_info = await redis.info() + redis_stats = { + "connected_clients": redis_info.get("connected_clients", 0), + "used_memory_human": redis_info.get("used_memory_human", "0B"), + "used_memory_peak_human": redis_info.get("used_memory_peak_human", "0B"), + "total_commands_processed": redis_info.get("total_commands_processed", 0), + "keyspace_hits": redis_info.get("keyspace_hits", 0), + "keyspace_misses": redis_info.get("keyspace_misses", 0), + "evicted_keys": redis_info.get("evicted_keys", 0), + "expired_keys": redis_info.get("expired_keys", 0), + } + + # 计算缓存命中率 + hits = redis_stats["keyspace_hits"] + misses = redis_stats["keyspace_misses"] + hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0 + redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2) + + return CacheStatsResponse( + user_cache=user_cache_stats, + redis_info=redis_stats + ) + + except Exception as e: + raise HTTPException(500, f"Failed to get cache stats: {str(e)}") + + +@router.post( + "/cache/invalidate/{user_id}", + name="清除指定用户缓存", + description="清除指定用户的所有缓存数据,需要管理员权限。", + tags=["缓存管理"], +) +async def invalidate_user_cache( + user_id: int, + redis: Redis = Depends(get_redis), + # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 +): + try: + cache_service = get_user_cache_service(redis) + await cache_service.invalidate_user_cache(user_id) + await cache_service.invalidate_v1_user_cache(user_id) + return {"message": f"Cache invalidated for user {user_id}"} + except Exception as e: + raise HTTPException(500, f"Failed to invalidate cache: {str(e)}") + + +@router.post( + "/cache/clear", + name="清除所有用户缓存", + description="清除所有用户相关的缓存数据,需要管理员权限。谨慎使用!", + tags=["缓存管理"], +) +async def clear_all_user_cache( + redis: Redis = Depends(get_redis), + # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 +): + try: + # 获取所有用户相关的缓存键 + user_keys = await redis.keys("user:*") + v1_user_keys = await redis.keys("v1_user:*") + all_keys = user_keys + v1_user_keys + + if 
all_keys: + await redis.delete(*all_keys) + return {"message": f"Cleared {len(all_keys)} cache entries"} + else: + return {"message": "No cache entries found"} + + except Exception as e: + raise HTTPException(500, f"Failed to clear cache: {str(e)}") + + +class CacheWarmupRequest(BaseModel): + user_ids: list[int] | None = None + limit: int = 100 + + +@router.post( + "/cache/warmup", + name="缓存预热", + description="对指定用户或活跃用户进行缓存预热,需要管理员权限。", + tags=["缓存管理"], +) +async def warmup_cache( + request: CacheWarmupRequest, + redis: Redis = Depends(get_redis), + # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 +): + try: + cache_service = get_user_cache_service(redis) + + if request.user_ids: + # 预热指定用户 + from app.dependencies.database import with_db + async with with_db() as session: + await cache_service.preload_user_cache(session, request.user_ids) + return {"message": f"Warmed up cache for {len(request.user_ids)} users"} + else: + # 预热活跃用户 + from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task + await schedule_user_cache_preload_task() + return {"message": f"Warmed up cache for top {request.limit} active users"} + + except Exception as e: + raise HTTPException(500, f"Failed to warmup cache: {str(e)}") diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 192f383..648505a 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -47,6 +47,7 @@ from app.models.score import ( Rank, SoloScoreSubmissionInfo, ) +from app.service.user_cache_service import get_user_cache_service from app.storage.base import StorageService from app.storage.local import LocalStorageService @@ -182,6 +183,17 @@ async def submit_score( } db.add(rank_event) await db.commit() + + # 成绩提交后刷新用户缓存 + try: + user_cache_service = get_user_cache_service(redis) + if current_user.id is not None: + await user_cache_service.refresh_user_cache_on_score_submit( + db, current_user.id, score.gamemode + ) + except Exception as e: + logger.error(f"Failed to refresh user cache after score submit: {e}") + background_task.add_task(process_user_achievement, resp.id) return resp diff --git a/app/router/v2/user.py b/app/router/v2/user.py index e18cd0b..6687e07 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio from datetime import UTC, datetime, timedelta from typing import Literal +from app.config import settings from app.const import BANCHOBOT_ID from app.database import ( BeatmapPlaycounts, @@ -15,10 +17,11 @@ from app.database.events import EventResp from app.database.lazer_user import SEARCH_INCLUDED from app.database.pp_best_score import PPBestScore from app.database.score import Score, ScoreResp -from app.dependencies.database import Database +from app.dependencies.database import Database, get_redis from app.dependencies.user import get_current_user from app.models.score import GameMode from app.models.user import BeatmapsetType +from app.service.user_cache_service import get_user_cache_service from .router import router @@ -51,23 +54,55 @@ async def get_users( default=False, description="是否包含各模式的统计信息" ), # TODO: future use ): + redis = get_redis() + cache_service = get_user_cache_service(redis) + if user_ids: - searched_users = ( - await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids))) - ).all() + # 先尝试从缓存获取 + cached_users = [] + uncached_user_ids = [] + + for user_id in user_ids[:50]: # 限制50个 + cached_user = await cache_service.get_user_from_cache(user_id) + if 
cached_user: + cached_users.append(cached_user) + else: + uncached_user_ids.append(user_id) + + # 查询未缓存的用户 + if uncached_user_ids: + searched_users = ( + await session.exec(select(User).where(col(User.id).in_(uncached_user_ids))) + ).all() + + # 将查询到的用户添加到缓存并返回 + for searched_user in searched_users: + if searched_user.id != BANCHOBOT_ID: + user_resp = await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDED, + ) + cached_users.append(user_resp) + # 异步缓存,不阻塞响应 + asyncio.create_task(cache_service.cache_user(user_resp)) + + return BatchUserResponse(users=cached_users) else: searched_users = (await session.exec(select(User).limit(50))).all() - return BatchUserResponse( - users=[ - await UserResp.from_db( - searched_user, - session, - include=SEARCH_INCLUDED, - ) - for searched_user in searched_users - if searched_user.id != BANCHOBOT_ID - ] - ) + users = [] + for searched_user in searched_users: + if searched_user.id != BANCHOBOT_ID: + user_resp = await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDED, + ) + users.append(user_resp) + # 异步缓存 + asyncio.create_task(cache_service.cache_user(user_resp)) + + return BatchUserResponse(users=users) @router.get( @@ -83,6 +118,16 @@ async def get_user_info_ruleset( ruleset: GameMode | None = Path(description="指定 ruleset"), # current_user: User = Security(get_current_user, scopes=["public"]), ): + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 如果是数字ID,先尝试从缓存获取 + if user_id.isdigit(): + user_id_int = int(user_id) + cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset) + if cached_user: + return cached_user + searched_user = ( await session.exec( select(User).where( @@ -94,12 +139,18 @@ async def get_user_info_ruleset( ).first() if not searched_user or searched_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") - return await UserResp.from_db( + + user_resp = await UserResp.from_db( searched_user, session, include=SEARCH_INCLUDED, ruleset=ruleset, ) + + # 异步缓存结果 + asyncio.create_task(cache_service.cache_user(user_resp, ruleset)) + + return user_resp @router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False) @@ -115,6 +166,16 @@ async def get_user_info( user_id: str = Path(description="用户 ID 或用户名"), # current_user: User = Security(get_current_user, scopes=["public"]), ): + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 如果是数字ID,先尝试从缓存获取 + if user_id.isdigit(): + user_id_int = int(user_id) + cached_user = await cache_service.get_user_from_cache(user_id_int) + if cached_user: + return cached_user + searched_user = ( await session.exec( select(User).where( @@ -126,11 +187,17 @@ async def get_user_info( ).first() if not searched_user or searched_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") - return await UserResp.from_db( + + user_resp = await UserResp.from_db( searched_user, session, include=SEARCH_INCLUDED, ) + + # 异步缓存结果 + asyncio.create_task(cache_service.cache_user(user_resp)) + + return user_resp @router.get( @@ -148,6 +215,20 @@ async def get_user_beatmapsets( limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), offset: int = Query(0, ge=0, description="偏移量"), ): + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 先尝试从缓存获取 + cached_result = await cache_service.get_user_beatmapsets_from_cache( + user_id, type.value, limit, offset + ) + if cached_result is not None: + # 根据类型恢复对象 + if type == BeatmapsetType.MOST_PLAYED: + 
return [BeatmapPlaycountsResp(**item) for item in cached_result] + else: + return [BeatmapsetResp(**item) for item in cached_result] + user = await session.get(User, user_id) if not user or user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") @@ -190,6 +271,11 @@ async def get_user_beatmapsets( else: raise HTTPException(400, detail="Invalid beatmapset type") + # 异步缓存结果 + asyncio.create_task( + cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset) + ) + return resp @@ -218,6 +304,17 @@ async def get_user_scores( offset: int = Query(0, ge=0, description="偏移量"), current_user: User = Security(get_current_user, scopes=["public"]), ): + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 先尝试从缓存获取(对于recent类型使用较短的缓存时间) + cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds + cached_scores = await cache_service.get_user_scores_from_cache( + user_id, type, mode, limit, offset + ) + if cached_scores is not None: + return cached_scores + db_user = await session.get(User, user_id) if not db_user or db_user.id == BANCHOBOT_ID: raise HTTPException(404, detail="User not found") @@ -253,13 +350,23 @@ async def get_user_scores( ).all() if not scores: return [] - return [ + + score_responses = [ await ScoreResp.from_db( session, score, ) for score in scores ] + + # 异步缓存结果 + asyncio.create_task( + cache_service.cache_user_scores( + user_id, type, score_responses, mode, limit, offset, cache_expire + ) + ) + + return score_responses @router.get( diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py index 4c4e8ae..72722e3 100644 --- a/app/scheduler/cache_scheduler.py +++ b/app/scheduler/cache_scheduler.py @@ -6,6 +6,11 @@ from app.config import settings from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.log import logger +from app.scheduler.user_cache_scheduler import ( + schedule_user_cache_cleanup_task, + schedule_user_cache_preload_task, + schedule_user_cache_warmup_task, +) class CacheScheduler: @@ -42,17 +47,26 @@ class CacheScheduler: # 启动时执行一次排行榜缓存刷新 await self._refresh_ranking_cache() + + # 启动时执行一次用户缓存预热 + await self._warmup_user_cache() beatmap_cache_counter = 0 ranking_cache_counter = 0 + user_cache_counter = 0 + user_cleanup_counter = 0 # 从配置文件获取间隔设置 check_interval = 5 * 60 # 5分钟检查间隔 beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔 ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取 + user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔 + user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔 beatmap_cache_cycles = beatmap_cache_interval // check_interval ranking_cache_cycles = ranking_cache_interval // check_interval + user_cache_cycles = user_cache_interval // check_interval + user_cleanup_cycles = user_cleanup_interval // check_interval while self.running: try: @@ -64,6 +78,8 @@ class CacheScheduler: beatmap_cache_counter += 1 ranking_cache_counter += 1 + user_cache_counter += 1 + user_cleanup_counter += 1 # beatmap缓存预热 if beatmap_cache_counter >= beatmap_cache_cycles: @@ -74,6 +90,16 @@ class CacheScheduler: if ranking_cache_counter >= ranking_cache_cycles: await self._refresh_ranking_cache() ranking_cache_counter = 0 + + # 用户缓存预加载 + if user_cache_counter >= user_cache_cycles: + await self._preload_user_cache() + user_cache_counter = 0 + + # 用户缓存清理 + if user_cleanup_counter >= user_cleanup_cycles: + await self._cleanup_user_cache() + user_cleanup_counter = 0 except asyncio.CancelledError: 
break @@ -110,16 +136,37 @@ class CacheScheduler: schedule_ranking_refresh_task, ) - # 获取数据库会话 - async for session in get_db(): + # 使用独立的数据库会话 + from app.dependencies.database import with_db + async with with_db() as session: await schedule_ranking_refresh_task(session, redis) - break # 只需要一次会话 logger.info("Ranking cache refresh completed successfully") except Exception as e: logger.error(f"Ranking cache refresh failed: {e}") + async def _warmup_user_cache(self): + """用户缓存预热""" + try: + await schedule_user_cache_warmup_task() + except Exception as e: + logger.error(f"User cache warmup failed: {e}") + + async def _preload_user_cache(self): + """用户缓存预加载""" + try: + await schedule_user_cache_preload_task() + except Exception as e: + logger.error(f"User cache preload failed: {e}") + + async def _cleanup_user_cache(self): + """用户缓存清理""" + try: + await schedule_user_cache_cleanup_task() + except Exception as e: + logger.error(f"User cache cleanup failed: {e}") + # Beatmap缓存调度器(保持向后兼容) class BeatmapsetCacheScheduler(CacheScheduler): diff --git a/app/scheduler/user_cache_scheduler.py b/app/scheduler/user_cache_scheduler.py new file mode 100644 index 0000000..cf43af2 --- /dev/null +++ b/app/scheduler/user_cache_scheduler.py @@ -0,0 +1,120 @@ +""" +用户缓存预热任务调度器 +""" +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime, timedelta + +from app.config import settings +from app.database import User +from app.database.score import Score +from app.dependencies.database import get_db, get_redis +from app.log import logger +from app.service.user_cache_service import get_user_cache_service + +from sqlmodel import col, func, select + + +async def schedule_user_cache_preload_task(): + """定时用户缓存预加载任务""" + # 默认启用用户缓存预加载,除非明确禁用 + enable_user_cache_preload = getattr(settings, "enable_user_cache_preload", True) + if not enable_user_cache_preload: + return + + try: + logger.info("Starting user cache preload task...") + + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 使用独立的数据库会话 + from app.dependencies.database import with_db + async with with_db() as session: + # 获取最近24小时内活跃的用户(提交过成绩的用户) + recent_time = datetime.now(UTC) - timedelta(hours=24) + + active_user_ids = ( + await session.exec( + select(Score.user_id, func.count().label("score_count")) + .where(col(Score.ended_at) >= recent_time) + .group_by(col(Score.user_id)) + .order_by(col("score_count").desc()) + .limit(settings.user_cache_max_preload_users) # 使用配置中的限制 + ) + ).all() + + if active_user_ids: + user_ids = [row[0] for row in active_user_ids] + await cache_service.preload_user_cache(session, user_ids) + logger.info(f"Preloaded cache for {len(user_ids)} active users") + else: + logger.info("No active users found for cache preload") + + logger.info("User cache preload task completed successfully") + + except Exception as e: + logger.error(f"User cache preload task failed: {e}") + + +async def schedule_user_cache_warmup_task(): + """定时用户缓存预热任务 - 预加载排行榜前100用户""" + try: + logger.info("Starting user cache warmup task...") + + redis = get_redis() + cache_service = get_user_cache_service(redis) + + # 使用独立的数据库会话 + from app.dependencies.database import with_db + async with with_db() as session: + # 获取全球排行榜前100的用户 + from app.database.statistics import UserStatistics + from app.models.score import GameMode + + for mode in GameMode: + try: + top_users = ( + await session.exec( + select(UserStatistics.user_id) + .where(UserStatistics.mode == mode) + .order_by(col(UserStatistics.pp).desc()) + .limit(100) + ) + 
).all() + + if top_users: + user_ids = list(top_users) + await cache_service.preload_user_cache(session, user_ids) + logger.info(f"Warmed cache for top 100 users in {mode}") + + # 避免过载,稍微延迟 + await asyncio.sleep(1) + + except Exception as e: + logger.error(f"Failed to warm cache for {mode}: {e}") + continue + + logger.info("User cache warmup task completed successfully") + + except Exception as e: + logger.error(f"User cache warmup task failed: {e}") + + +async def schedule_user_cache_cleanup_task(): + """定时用户缓存清理任务""" + try: + logger.info("Starting user cache cleanup task...") + + redis = get_redis() + + # 清理过期的用户缓存(Redis会自动处理TTL,这里主要记录统计信息) + cache_service = get_user_cache_service(redis) + stats = await cache_service.get_cache_stats() + + logger.info(f"User cache stats: {stats}") + logger.info("User cache cleanup task completed successfully") + + except Exception as e: + logger.error(f"User cache cleanup task failed: {e}") diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py index a2b9894..0cd75ac 100644 --- a/app/service/ranking_cache_service.py +++ b/app/service/ranking_cache_service.py @@ -421,7 +421,7 @@ class RankingCacheService: select(User.country_code, func.count().label("user_count")) .where(col(User.is_active).is_(True)) .group_by(User.country_code) - .order_by(col("user_count").desc()) + .order_by(func.count().desc()) .limit(settings.ranking_cache_top_countries) ) ).all() diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py new file mode 100644 index 0000000..ed0926e --- /dev/null +++ b/app/service/user_cache_service.py @@ -0,0 +1,388 @@ +""" +用户缓存服务 +用于缓存用户信息,提供热缓存和实时刷新功能 +""" +from __future__ import annotations + +import asyncio +import json +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Any, Literal + +from app.config import settings +from app.const import BANCHOBOT_ID +from app.database import User, UserResp +from app.database.lazer_user import SEARCH_INCLUDED +from app.database.pp_best_score import PPBestScore +from app.database.score import Score, ScoreResp +from app.log import logger +from app.models.score import GameMode + +from redis.asyncio import Redis +from sqlmodel import col, exists, select +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + pass + + +class DateTimeEncoder(json.JSONEncoder): + """自定义 JSON 编码器,支持 datetime 序列化""" + + def default(self, obj): + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + +def safe_json_dumps(data: Any) -> str: + """安全的 JSON 序列化,支持 datetime 对象""" + return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False) + + +class UserCacheService: + """用户缓存服务""" + + def __init__(self, redis: Redis): + self.redis = redis + self._refreshing = False + self._background_tasks: set = set() + + def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str: + """生成 V1 用户缓存键""" + if ruleset: + return f"v1_user:{user_id}:ruleset:{ruleset}" + return f"v1_user:{user_id}" + + async def get_v1_user_from_cache( + self, + user_id: int, + ruleset: GameMode | None = None + ) -> dict | None: + """从缓存获取 V1 用户信息""" + try: + cache_key = self._get_v1_user_cache_key(user_id, ruleset) + cached_data = await self.redis.get(cache_key) + if cached_data: + logger.debug(f"V1 User cache hit for user {user_id}") + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error getting V1 user from cache: {e}") + return None + + async def cache_v1_user( 
+ self, + user_data: dict, + user_id: int, + ruleset: GameMode | None = None, + expire_seconds: int | None = None + ): + """缓存 V1 用户信息""" + try: + if expire_seconds is None: + expire_seconds = settings.user_cache_expire_seconds + cache_key = self._get_v1_user_cache_key(user_id, ruleset) + cached_data = safe_json_dumps(user_data) + await self.redis.setex(cache_key, expire_seconds, cached_data) + logger.debug(f"Cached V1 user {user_id} for {expire_seconds}s") + except Exception as e: + logger.error(f"Error caching V1 user: {e}") + + async def invalidate_v1_user_cache(self, user_id: int): + """使 V1 用户缓存失效""" + try: + # 删除 V1 用户信息缓存 + pattern = f"v1_user:{user_id}*" + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}") + except Exception as e: + logger.error(f"Error invalidating V1 user cache: {e}") + + def _get_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str: + """生成用户缓存键""" + if ruleset: + return f"user:{user_id}:ruleset:{ruleset}" + return f"user:{user_id}" + + def _get_user_scores_cache_key( + self, + user_id: int, + score_type: str, + mode: GameMode | None = None, + limit: int = 100, + offset: int = 0 + ) -> str: + """生成用户成绩缓存键""" + mode_part = f":{mode}" if mode else "" + return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}" + + def _get_user_beatmapsets_cache_key( + self, + user_id: int, + beatmapset_type: str, + limit: int = 100, + offset: int = 0 + ) -> str: + """生成用户谱面集缓存键""" + return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}" + + async def get_user_from_cache( + self, + user_id: int, + ruleset: GameMode | None = None + ) -> UserResp | None: + """从缓存获取用户信息""" + try: + cache_key = self._get_user_cache_key(user_id, ruleset) + cached_data = await self.redis.get(cache_key) + if cached_data: + logger.debug(f"User cache hit for user {user_id}") + data = json.loads(cached_data) + return UserResp(**data) + return None + except Exception as e: + logger.error(f"Error getting user from cache: {e}") + return None + + async def cache_user( + self, + user_resp: UserResp, + ruleset: GameMode | None = None, + expire_seconds: int | None = None + ): + """缓存用户信息""" + try: + if expire_seconds is None: + expire_seconds = settings.user_cache_expire_seconds + if user_resp.id is None: + logger.warning("Cannot cache user with None id") + return + cache_key = self._get_user_cache_key(user_resp.id, ruleset) + cached_data = user_resp.model_dump_json() + await self.redis.setex(cache_key, expire_seconds, cached_data) + logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s") + except Exception as e: + logger.error(f"Error caching user: {e}") + + async def get_user_scores_from_cache( + self, + user_id: int, + score_type: str, + mode: GameMode | None = None, + limit: int = 100, + offset: int = 0 + ) -> list[ScoreResp] | None: + """从缓存获取用户成绩""" + try: + cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset) + cached_data = await self.redis.get(cache_key) + if cached_data: + logger.debug(f"User scores cache hit for user {user_id}, type {score_type}") + data = json.loads(cached_data) + return [ScoreResp(**score_data) for score_data in data] + return None + except Exception as e: + logger.error(f"Error getting user scores from cache: {e}") + return None + + async def cache_user_scores( + self, + user_id: int, + score_type: str, + scores: list[ScoreResp], + mode: GameMode | None = 
None, + limit: int = 100, + offset: int = 0, + expire_seconds: int | None = None + ): + """缓存用户成绩""" + try: + if expire_seconds is None: + expire_seconds = settings.user_scores_cache_expire_seconds + cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset) + # 使用 model_dump_json() 而不是 model_dump() + json.dumps() + scores_json_list = [score.model_dump_json() for score in scores] + cached_data = f"[{','.join(scores_json_list)}]" + await self.redis.setex(cache_key, expire_seconds, cached_data) + logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s") + except Exception as e: + logger.error(f"Error caching user scores: {e}") + + async def get_user_beatmapsets_from_cache( + self, + user_id: int, + beatmapset_type: str, + limit: int = 100, + offset: int = 0 + ) -> list[Any] | None: + """从缓存获取用户谱面集""" + try: + cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset) + cached_data = await self.redis.get(cache_key) + if cached_data: + logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}") + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error getting user beatmapsets from cache: {e}") + return None + + async def cache_user_beatmapsets( + self, + user_id: int, + beatmapset_type: str, + beatmapsets: list[Any], + limit: int = 100, + offset: int = 0, + expire_seconds: int | None = None + ): + """缓存用户谱面集""" + try: + if expire_seconds is None: + expire_seconds = settings.user_beatmapsets_cache_expire_seconds + cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset) + # 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps + serialized_beatmapsets = [] + for bms in beatmapsets: + if hasattr(bms, 'model_dump_json'): + serialized_beatmapsets.append(bms.model_dump_json()) + else: + serialized_beatmapsets.append(safe_json_dumps(bms)) + cached_data = f"[{','.join(serialized_beatmapsets)}]" + await self.redis.setex(cache_key, expire_seconds, cached_data) + logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s") + except Exception as e: + logger.error(f"Error caching user beatmapsets: {e}") + + async def invalidate_user_cache(self, user_id: int): + """使用户缓存失效""" + try: + # 删除用户信息缓存 + pattern = f"user:{user_id}*" + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + logger.info(f"Invalidated {len(keys)} cache entries for user {user_id}") + except Exception as e: + logger.error(f"Error invalidating user cache: {e}") + + async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None): + """使用户成绩缓存失效""" + try: + # 删除用户成绩相关缓存 + mode_pattern = f":{mode}" if mode else "*" + pattern = f"user:{user_id}:scores:*{mode_pattern}*" + keys = await self.redis.keys(pattern) + if keys: + await self.redis.delete(*keys) + logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}") + except Exception as e: + logger.error(f"Error invalidating user scores cache: {e}") + + async def preload_user_cache(self, session: AsyncSession, user_ids: list[int]): + """预加载用户缓存""" + if self._refreshing: + return + + self._refreshing = True + try: + logger.info(f"Preloading cache for {len(user_ids)} users") + + # 批量获取用户 + users = ( + await session.exec( + select(User).where(col(User.id).in_(user_ids)) + ) + ).all() + + # 串行缓存用户信息,避免并发数据库访问问题 + cached_count = 0 + for user in users: + if user.id != BANCHOBOT_ID: + try: + await 
self._cache_single_user(user, session) + cached_count += 1 + except Exception as e: + logger.error(f"Failed to cache user {user.id}: {e}") + + logger.info(f"Preloaded cache for {cached_count} users") + + except Exception as e: + logger.error(f"Error preloading user cache: {e}") + finally: + self._refreshing = False + + async def _cache_single_user(self, user: User, session: AsyncSession): + """缓存单个用户""" + try: + user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED) + await self.cache_user(user_resp) + except Exception as e: + logger.error(f"Error caching single user {user.id}: {e}") + + async def refresh_user_cache_on_score_submit( + self, + session: AsyncSession, + user_id: int, + mode: GameMode + ): + """成绩提交后刷新用户缓存""" + try: + # 使相关缓存失效(包括 v1 和 v2) + await self.invalidate_user_cache(user_id) + await self.invalidate_v1_user_cache(user_id) + await self.invalidate_user_scores_cache(user_id, mode) + + # 立即重新加载用户信息 + user = await session.get(User, user_id) + if user and user.id != BANCHOBOT_ID: + await self._cache_single_user(user, session) + logger.info(f"Refreshed cache for user {user_id} after score submit") + except Exception as e: + logger.error(f"Error refreshing user cache on score submit: {e}") + + async def get_cache_stats(self) -> dict: + """获取缓存统计信息""" + try: + user_keys = await self.redis.keys("user:*") + v1_user_keys = await self.redis.keys("v1_user:*") + all_keys = user_keys + v1_user_keys + total_size = 0 + + for key in all_keys[:100]: # 限制检查数量 + try: + size = await self.redis.memory_usage(key) + if size: + total_size += size + except Exception: + continue + + return { + "cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]), + "cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]), + "cached_user_scores": len([k for k in user_keys if ":scores:" in k]), + "cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]), + "total_cached_entries": len(all_keys), + "estimated_total_size_mb": ( + round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 + ), + "refreshing": self._refreshing, + } + except Exception as e: + logger.error(f"Error getting user cache stats: {e}") + return {"error": str(e)} + + +# 全局缓存服务实例 +_user_cache_service: UserCacheService | None = None + + +def get_user_cache_service(redis: Redis) -> UserCacheService: + """获取用户缓存服务实例""" + global _user_cache_service + if _user_cache_service is None: + _user_cache_service = UserCacheService(redis) + return _user_cache_service
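
The field_validator hooks added to BeatmapsetResp and ScoreResp above coerce 0/1 integers
coming back from the Boolean columns into real booleans before the rest of validation runs.
A self-contained sketch of that coercion (the model and field names below are illustrative,
not taken from the patch):

    from pydantic import BaseModel, field_validator


    class ExampleResp(BaseModel):
        nsfw: bool = False
        video: bool = False

        @field_validator("nsfw", "video", mode="before")
        @classmethod
        def coerce_int_to_bool(cls, v):
            # Accept tinyint-style 0/1 values; hand real booleans through unchanged.
            if isinstance(v, int):
                return bool(v)
            return v


    print(ExampleResp(nsfw=1, video=0))  # nsfw=True video=False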
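
The route changes in app/router/v1/user.py and app/router/v2/user.py all follow the same
cache-aside flow that user_cache_service.py implements: try Redis, fall back to the database
on a miss, write the result back via asyncio.create_task so the response is not delayed, and
invalidate the user's keys when a score is submitted. A minimal standalone sketch of that
flow against redis.asyncio follows; load_user_from_db, UserModel, and the Redis URL are
placeholders, not code from the patch:

    from __future__ import annotations

    import asyncio

    from pydantic import BaseModel
    from redis.asyncio import Redis

    USER_CACHE_EXPIRE_SECONDS = 300  # mirrors settings.user_cache_expire_seconds


    class UserModel(BaseModel):
        id: int
        username: str


    async def load_user_from_db(user_id: int) -> UserModel | None:
        # Stand-in for the real SQLModel query (UserResp.from_db in the patch).
        return UserModel(id=user_id, username=f"user_{user_id}")


    async def get_user(redis: Redis, user_id: int) -> UserModel | None:
        key = f"user:{user_id}"
        cached = await redis.get(key)
        if cached:
            return UserModel.model_validate_json(cached)  # cache hit, no DB round trip
        user = await load_user_from_db(user_id)
        if user is not None:
            # Write back without blocking the caller, as the routes do with asyncio.create_task.
            asyncio.create_task(
                redis.setex(key, USER_CACHE_EXPIRE_SECONDS, user.model_dump_json())
            )
        return user


    async def invalidate_user(redis: Redis, user_id: int) -> None:
        # Pattern delete covering the profile, scores and beatmapsets keys,
        # like invalidate_user_cache in user_cache_service.py.
        keys = await redis.keys(f"user:{user_id}*")
        if keys:
            await redis.delete(*keys)


    async def main() -> None:
        redis = Redis.from_url("redis://localhost:6379/0", decode_responses=True)
        print(await get_user(redis, 1))  # miss: loads from the stand-in DB and caches
        await asyncio.sleep(0.1)         # give the fire-and-forget write-back a moment to land
        print(await get_user(redis, 1))  # now served from Redis
        await invalidate_user(redis, 1)
        await redis.close()


    if __name__ == "__main__":
        asyncio.run(main())

One caveat about the wildcard used here and in invalidate_user_cache: the pattern
user:{user_id}* also matches IDs that merely share a prefix (invalidating user 12 would
also hit user 123's keys), so a stricter variant deletes the exact user:{user_id} key
plus user:{user_id}:*.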