Add grade hot cache
This commit is contained in:
@@ -149,6 +149,14 @@ class Settings(BaseSettings):
|
||||
ranking_cache_refresh_interval_minutes: int = 10 # 排行榜缓存刷新间隔(分钟)
|
||||
ranking_cache_max_pages: int = 20 # 最多缓存的页数
|
||||
ranking_cache_top_countries: int = 20 # 缓存前N个国家的排行榜
|
||||
|
||||
# 用户缓存设置
|
||||
enable_user_cache_preload: bool = True # 启用用户缓存预加载
|
||||
user_cache_expire_seconds: int = 300 # 用户信息缓存过期时间(秒)
|
||||
user_scores_cache_expire_seconds: int = 60 # 用户成绩缓存过期时间(秒)
|
||||
user_beatmapsets_cache_expire_seconds: int = 600 # 用户谱面集缓存过期时间(秒)
|
||||
user_cache_max_preload_users: int = 200 # 最多预加载的用户数量
|
||||
user_cache_concurrent_limit: int = 10 # 并发缓存用户的限制
|
||||
|
||||
# 反作弊设置
|
||||
suspicious_score_check: bool = True
|
||||
|
||||
@@ -7,8 +7,8 @@ from app.models.score import GameMode
|
||||
|
||||
from .lazer_user import BASE_INCLUDES, User, UserResp
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, Column, DateTime, Text
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from sqlalchemy import Boolean, JSON, Column, DateTime, Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -73,16 +73,16 @@ class BeatmapsetBase(SQLModel):
|
||||
artist_unicode: str = Field(index=True)
|
||||
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
|
||||
creator: str = Field(index=True)
|
||||
nsfw: bool = Field(default=False)
|
||||
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
play_count: int = Field(index=True)
|
||||
preview_url: str
|
||||
source: str = Field(default="")
|
||||
|
||||
spotlight: bool = Field(default=False)
|
||||
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
title: str = Field(index=True)
|
||||
title_unicode: str = Field(index=True)
|
||||
user_id: int = Field(index=True)
|
||||
video: bool = Field(index=True)
|
||||
video: bool = Field(sa_column=Column(Boolean, index=True))
|
||||
|
||||
# optional
|
||||
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
|
||||
@@ -102,13 +102,13 @@ class BeatmapsetBase(SQLModel):
|
||||
|
||||
# BeatmapsetExtended
|
||||
bpm: float = Field(default=0.0)
|
||||
can_be_hyped: bool = Field(default=False)
|
||||
discussion_locked: bool = Field(default=False)
|
||||
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
ranked_date: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime, index=True)
|
||||
)
|
||||
storyboard: bool = Field(default=False, index=True)
|
||||
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
|
||||
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
tags: str = Field(default="", sa_column=Column(Text))
|
||||
|
||||
@@ -133,7 +133,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
hype_current: int = Field(default=0)
|
||||
hype_required: int = Field(default=0)
|
||||
availability_info: str | None = Field(default=None)
|
||||
download_disabled: bool = Field(default=False)
|
||||
download_disabled: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||
|
||||
@classmethod
|
||||
@@ -205,6 +205,14 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
favourite_count: int = 0
|
||||
recent_favourites: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
@field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before')
|
||||
@classmethod
|
||||
def validate_bool_fields(cls, v):
|
||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
||||
if isinstance(v, int):
|
||||
return bool(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def fix_genre_language(self) -> Self:
|
||||
if self.genre is None:
|
||||
|
||||
35
app/database/field_utils.py
Normal file
35
app/database/field_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
数据库字段类型工具
|
||||
提供处理数据库和 Pydantic 之间类型转换的工具
|
||||
"""
|
||||
from typing import Any, Union
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy import Boolean
|
||||
|
||||
|
||||
def bool_field_validator(field_name: str):
|
||||
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
|
||||
@field_validator(field_name, mode="before")
|
||||
@classmethod
|
||||
def validate_bool_field(cls, v: Any) -> bool:
|
||||
"""将整数 0/1 转换为布尔值"""
|
||||
if isinstance(v, int):
|
||||
return bool(v)
|
||||
return v
|
||||
return validate_bool_field
|
||||
|
||||
|
||||
def create_bool_field(**kwargs):
|
||||
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
|
||||
from sqlmodel import Field, Column
|
||||
|
||||
# 如果没有指定 sa_column,则使用 Boolean 类型
|
||||
if 'sa_column' not in kwargs:
|
||||
# 处理 index 参数
|
||||
index = kwargs.pop('index', False)
|
||||
if index:
|
||||
kwargs['sa_column'] = Column(Boolean, index=True)
|
||||
else:
|
||||
kwargs['sa_column'] = Column(Boolean)
|
||||
|
||||
return Field(**kwargs)
|
||||
@@ -45,8 +45,9 @@ from .relationship import (
|
||||
)
|
||||
from .score_token import ScoreToken
|
||||
|
||||
from pydantic import field_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
@@ -79,13 +80,13 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
default=0, sa_column=Column(BigInteger)
|
||||
) # solo_score
|
||||
ended_at: datetime = Field(sa_column=Column(DateTime))
|
||||
has_replay: bool
|
||||
has_replay: bool = Field(sa_column=Column(Boolean))
|
||||
max_combo: int
|
||||
mods: list[APIMod] = Field(sa_column=Column(JSON))
|
||||
passed: bool
|
||||
passed: bool = Field(sa_column=Column(Boolean))
|
||||
playlist_item_id: int | None = Field(default=None) # multiplayer
|
||||
pp: float = Field(default=0.0)
|
||||
preserve: bool = Field(default=True)
|
||||
preserve: bool = Field(default=True, sa_column=Column(Boolean))
|
||||
rank: Rank
|
||||
room_id: int | None = Field(default=None) # multiplayer
|
||||
started_at: datetime = Field(sa_column=Column(DateTime))
|
||||
@@ -176,6 +177,14 @@ class ScoreResp(ScoreBase):
|
||||
ranked: bool = False
|
||||
current_user_attributes: CurrentUserAttributes | None = None
|
||||
|
||||
@field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before')
|
||||
@classmethod
|
||||
def validate_bool_fields(cls, v):
|
||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
||||
if isinstance(v, int):
|
||||
return bool(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
s = cls.model_validate(score.model_dump())
|
||||
|
||||
@@ -28,9 +28,11 @@ def json_serializer(value):
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
json_serializer=json_serializer,
|
||||
pool_size=20,
|
||||
max_overflow=20,
|
||||
pool_size=30, # 增加连接池大小
|
||||
max_overflow=50, # 增加最大溢出连接数
|
||||
pool_timeout=30.0,
|
||||
pool_recycle=3600, # 1小时回收连接
|
||||
pool_pre_ping=True, # 启用连接预检查
|
||||
)
|
||||
|
||||
# Redis 连接
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database.lazer_user import User
|
||||
from app.database.statistics import UserStatistics, UserStatisticsResp
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.log import logger
|
||||
from app.models.score import GameMode
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import AllStrModel, router
|
||||
|
||||
@@ -38,10 +42,21 @@ class V1User(AllStrModel):
|
||||
pp_country_rank: int
|
||||
events: list[dict]
|
||||
|
||||
@classmethod
|
||||
def _get_cache_key(cls, user_id: int, ruleset: GameMode | None = None) -> str:
|
||||
"""生成 V1 用户缓存键"""
|
||||
if ruleset:
|
||||
return f"v1_user:{user_id}:ruleset:{ruleset}"
|
||||
return f"v1_user:{user_id}"
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: Database, db_user: User, ruleset: GameMode | None = None
|
||||
) -> "V1User":
|
||||
# 确保 user_id 不为 None
|
||||
if db_user.id is None:
|
||||
raise ValueError("User ID cannot be None")
|
||||
|
||||
ruleset = ruleset or db_user.playmode
|
||||
current_statistics: UserStatistics | None = None
|
||||
for i in await db_user.awaitable_attrs.statistics:
|
||||
@@ -101,24 +116,55 @@ async def get_user(
|
||||
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
|
||||
),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 确定查询方式和用户ID
|
||||
is_id_query = type == "id" or user.isdigit()
|
||||
|
||||
# 解析 ruleset
|
||||
ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None
|
||||
|
||||
# 如果是 ID 查询,先尝试从缓存获取
|
||||
cached_v1_user = None
|
||||
user_id_for_cache = None
|
||||
|
||||
if is_id_query:
|
||||
try:
|
||||
user_id_for_cache = int(user)
|
||||
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
|
||||
if cached_v1_user:
|
||||
return [V1User(**cached_v1_user)]
|
||||
except (ValueError, TypeError):
|
||||
pass # 不是有效的用户ID,继续数据库查询
|
||||
|
||||
# 从数据库查询用户
|
||||
db_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else User.username == user,
|
||||
User.id == user if is_id_query else User.username == user,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
if not db_user:
|
||||
return []
|
||||
|
||||
try:
|
||||
return [
|
||||
await V1User.from_db(
|
||||
session,
|
||||
db_user,
|
||||
GameMode.from_int_extra(ruleset_id) if ruleset_id else None,
|
||||
# 生成用户数据
|
||||
v1_user = await V1User.from_db(session, db_user, ruleset)
|
||||
|
||||
# 异步缓存结果(如果有用户ID)
|
||||
if db_user.id is not None:
|
||||
user_data = v1_user.model_dump()
|
||||
asyncio.create_task(
|
||||
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
|
||||
)
|
||||
]
|
||||
|
||||
return [v1_user]
|
||||
|
||||
except KeyError:
|
||||
raise HTTPException(400, "Invalid request")
|
||||
except ValueError as e:
|
||||
logger.error(f"Error processing V1 user data: {e}")
|
||||
raise HTTPException(500, "Internal server error")
|
||||
|
||||
144
app/router/v2/cache.py
Normal file
144
app/router/v2/cache.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
缓存管理和监控接口
|
||||
提供缓存统计、清理和预热功能
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class CacheStatsResponse(BaseModel):
|
||||
user_cache: dict
|
||||
redis_info: dict
|
||||
|
||||
|
||||
@router.get(
|
||||
"/cache/stats",
|
||||
response_model=CacheStatsResponse,
|
||||
name="获取缓存统计信息",
|
||||
description="获取用户缓存和Redis的统计信息,需要管理员权限。",
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def get_cache_stats(
|
||||
redis: Redis = Depends(get_redis),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用
|
||||
):
|
||||
try:
|
||||
cache_service = get_user_cache_service(redis)
|
||||
user_cache_stats = await cache_service.get_cache_stats()
|
||||
|
||||
# 获取 Redis 基本信息
|
||||
redis_info = await redis.info()
|
||||
redis_stats = {
|
||||
"connected_clients": redis_info.get("connected_clients", 0),
|
||||
"used_memory_human": redis_info.get("used_memory_human", "0B"),
|
||||
"used_memory_peak_human": redis_info.get("used_memory_peak_human", "0B"),
|
||||
"total_commands_processed": redis_info.get("total_commands_processed", 0),
|
||||
"keyspace_hits": redis_info.get("keyspace_hits", 0),
|
||||
"keyspace_misses": redis_info.get("keyspace_misses", 0),
|
||||
"evicted_keys": redis_info.get("evicted_keys", 0),
|
||||
"expired_keys": redis_info.get("expired_keys", 0),
|
||||
}
|
||||
|
||||
# 计算缓存命中率
|
||||
hits = redis_stats["keyspace_hits"]
|
||||
misses = redis_stats["keyspace_misses"]
|
||||
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
|
||||
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
|
||||
|
||||
return CacheStatsResponse(
|
||||
user_cache=user_cache_stats,
|
||||
redis_info=redis_stats
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Failed to get cache stats: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cache/invalidate/{user_id}",
|
||||
name="清除指定用户缓存",
|
||||
description="清除指定用户的所有缓存数据,需要管理员权限。",
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def invalidate_user_cache(
|
||||
user_id: int,
|
||||
redis: Redis = Depends(get_redis),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
cache_service = get_user_cache_service(redis)
|
||||
await cache_service.invalidate_user_cache(user_id)
|
||||
await cache_service.invalidate_v1_user_cache(user_id)
|
||||
return {"message": f"Cache invalidated for user {user_id}"}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cache/clear",
|
||||
name="清除所有用户缓存",
|
||||
description="清除所有用户相关的缓存数据,需要管理员权限。谨慎使用!",
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def clear_all_user_cache(
|
||||
redis: Redis = Depends(get_redis),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
# 获取所有用户相关的缓存键
|
||||
user_keys = await redis.keys("user:*")
|
||||
v1_user_keys = await redis.keys("v1_user:*")
|
||||
all_keys = user_keys + v1_user_keys
|
||||
|
||||
if all_keys:
|
||||
await redis.delete(*all_keys)
|
||||
return {"message": f"Cleared {len(all_keys)} cache entries"}
|
||||
else:
|
||||
return {"message": "No cache entries found"}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Failed to clear cache: {str(e)}")
|
||||
|
||||
|
||||
class CacheWarmupRequest(BaseModel):
|
||||
user_ids: list[int] | None = None
|
||||
limit: int = 100
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cache/warmup",
|
||||
name="缓存预热",
|
||||
description="对指定用户或活跃用户进行缓存预热,需要管理员权限。",
|
||||
tags=["缓存管理"],
|
||||
)
|
||||
async def warmup_cache(
|
||||
request: CacheWarmupRequest,
|
||||
redis: Redis = Depends(get_redis),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释
|
||||
):
|
||||
try:
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
if request.user_ids:
|
||||
# 预热指定用户
|
||||
from app.dependencies.database import with_db
|
||||
async with with_db() as session:
|
||||
await cache_service.preload_user_cache(session, request.user_ids)
|
||||
return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
|
||||
else:
|
||||
# 预热活跃用户
|
||||
from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
|
||||
await schedule_user_cache_preload_task()
|
||||
return {"message": f"Warmed up cache for top {request.limit} active users"}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Failed to warmup cache: {str(e)}")
|
||||
@@ -47,6 +47,7 @@ from app.models.score import (
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
from app.storage.base import StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
@@ -182,6 +183,17 @@ async def submit_score(
|
||||
}
|
||||
db.add(rank_event)
|
||||
await db.commit()
|
||||
|
||||
# 成绩提交后刷新用户缓存
|
||||
try:
|
||||
user_cache_service = get_user_cache_service(redis)
|
||||
if current_user.id is not None:
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(
|
||||
db, current_user.id, score.gamemode
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh user cache after score submit: {e}")
|
||||
|
||||
background_task.add_task(process_user_achievement, resp.id)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import (
|
||||
BeatmapPlaycounts,
|
||||
@@ -15,10 +17,11 @@ from app.database.events import EventResp
|
||||
from app.database.lazer_user import SEARCH_INCLUDED
|
||||
from app.database.pp_best_score import PPBestScore
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import BeatmapsetType
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
@@ -51,23 +54,55 @@ async def get_users(
|
||||
default=False, description="是否包含各模式的统计信息"
|
||||
), # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
if user_ids:
|
||||
searched_users = (
|
||||
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
|
||||
).all()
|
||||
# 先尝试从缓存获取
|
||||
cached_users = []
|
||||
uncached_user_ids = []
|
||||
|
||||
for user_id in user_ids[:50]: # 限制50个
|
||||
cached_user = await cache_service.get_user_from_cache(user_id)
|
||||
if cached_user:
|
||||
cached_users.append(cached_user)
|
||||
else:
|
||||
uncached_user_ids.append(user_id)
|
||||
|
||||
# 查询未缓存的用户
|
||||
if uncached_user_ids:
|
||||
searched_users = (
|
||||
await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
|
||||
).all()
|
||||
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
if searched_user.id != BANCHOBOT_ID:
|
||||
user_resp = await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
|
||||
return BatchUserResponse(users=cached_users)
|
||||
else:
|
||||
searched_users = (await session.exec(select(User).limit(50))).all()
|
||||
return BatchUserResponse(
|
||||
users=[
|
||||
await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
for searched_user in searched_users
|
||||
if searched_user.id != BANCHOBOT_ID
|
||||
]
|
||||
)
|
||||
users = []
|
||||
for searched_user in searched_users:
|
||||
if searched_user.id != BANCHOBOT_ID:
|
||||
user_resp = await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
|
||||
return BatchUserResponse(users=users)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -83,6 +118,16 @@ async def get_user_info_ruleset(
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 如果是数字ID,先尝试从缓存获取
|
||||
if user_id.isdigit():
|
||||
user_id_int = int(user_id)
|
||||
cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset)
|
||||
if cached_user:
|
||||
return cached_user
|
||||
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
@@ -94,12 +139,18 @@ async def get_user_info_ruleset(
|
||||
).first()
|
||||
if not searched_user or searched_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await UserResp.from_db(
|
||||
|
||||
user_resp = await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
ruleset=ruleset,
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
|
||||
|
||||
return user_resp
|
||||
|
||||
|
||||
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
|
||||
@@ -115,6 +166,16 @@ async def get_user_info(
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 如果是数字ID,先尝试从缓存获取
|
||||
if user_id.isdigit():
|
||||
user_id_int = int(user_id)
|
||||
cached_user = await cache_service.get_user_from_cache(user_id_int)
|
||||
if cached_user:
|
||||
return cached_user
|
||||
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
@@ -126,11 +187,17 @@ async def get_user_info(
|
||||
).first()
|
||||
if not searched_user or searched_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await UserResp.from_db(
|
||||
|
||||
user_resp = await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
|
||||
return user_resp
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -148,6 +215,20 @@ async def get_user_beatmapsets(
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(
|
||||
user_id, type.value, limit, offset
|
||||
)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
return [BeatmapPlaycountsResp(**item) for item in cached_result]
|
||||
else:
|
||||
return [BeatmapsetResp(**item) for item in cached_result]
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if not user or user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
@@ -190,6 +271,11 @@ async def get_user_beatmapsets(
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(
|
||||
cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@@ -218,6 +304,17 @@ async def get_user_scores(
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
|
||||
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(
|
||||
user_id, type, mode, limit, offset
|
||||
)
|
||||
if cached_scores is not None:
|
||||
return cached_scores
|
||||
|
||||
db_user = await session.get(User, user_id)
|
||||
if not db_user or db_user.id == BANCHOBOT_ID:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
@@ -253,13 +350,23 @@ async def get_user_scores(
|
||||
).all()
|
||||
if not scores:
|
||||
return []
|
||||
return [
|
||||
|
||||
score_responses = [
|
||||
await ScoreResp.from_db(
|
||||
session,
|
||||
score,
|
||||
)
|
||||
for score in scores
|
||||
]
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(
|
||||
cache_service.cache_user_scores(
|
||||
user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
)
|
||||
|
||||
return score_responses
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -6,6 +6,11 @@ from app.config import settings
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.log import logger
|
||||
from app.scheduler.user_cache_scheduler import (
|
||||
schedule_user_cache_cleanup_task,
|
||||
schedule_user_cache_preload_task,
|
||||
schedule_user_cache_warmup_task,
|
||||
)
|
||||
|
||||
|
||||
class CacheScheduler:
|
||||
@@ -42,17 +47,26 @@ class CacheScheduler:
|
||||
|
||||
# 启动时执行一次排行榜缓存刷新
|
||||
await self._refresh_ranking_cache()
|
||||
|
||||
# 启动时执行一次用户缓存预热
|
||||
await self._warmup_user_cache()
|
||||
|
||||
beatmap_cache_counter = 0
|
||||
ranking_cache_counter = 0
|
||||
user_cache_counter = 0
|
||||
user_cleanup_counter = 0
|
||||
|
||||
# 从配置文件获取间隔设置
|
||||
check_interval = 5 * 60 # 5分钟检查间隔
|
||||
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
|
||||
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
|
||||
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
|
||||
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔
|
||||
|
||||
beatmap_cache_cycles = beatmap_cache_interval // check_interval
|
||||
ranking_cache_cycles = ranking_cache_interval // check_interval
|
||||
user_cache_cycles = user_cache_interval // check_interval
|
||||
user_cleanup_cycles = user_cleanup_interval // check_interval
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
@@ -64,6 +78,8 @@ class CacheScheduler:
|
||||
|
||||
beatmap_cache_counter += 1
|
||||
ranking_cache_counter += 1
|
||||
user_cache_counter += 1
|
||||
user_cleanup_counter += 1
|
||||
|
||||
# beatmap缓存预热
|
||||
if beatmap_cache_counter >= beatmap_cache_cycles:
|
||||
@@ -74,6 +90,16 @@ class CacheScheduler:
|
||||
if ranking_cache_counter >= ranking_cache_cycles:
|
||||
await self._refresh_ranking_cache()
|
||||
ranking_cache_counter = 0
|
||||
|
||||
# 用户缓存预加载
|
||||
if user_cache_counter >= user_cache_cycles:
|
||||
await self._preload_user_cache()
|
||||
user_cache_counter = 0
|
||||
|
||||
# 用户缓存清理
|
||||
if user_cleanup_counter >= user_cleanup_cycles:
|
||||
await self._cleanup_user_cache()
|
||||
user_cleanup_counter = 0
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
@@ -110,16 +136,37 @@ class CacheScheduler:
|
||||
schedule_ranking_refresh_task,
|
||||
)
|
||||
|
||||
# 获取数据库会话
|
||||
async for session in get_db():
|
||||
# 使用独立的数据库会话
|
||||
from app.dependencies.database import with_db
|
||||
async with with_db() as session:
|
||||
await schedule_ranking_refresh_task(session, redis)
|
||||
break # 只需要一次会话
|
||||
|
||||
logger.info("Ranking cache refresh completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ranking cache refresh failed: {e}")
|
||||
|
||||
async def _warmup_user_cache(self):
|
||||
"""用户缓存预热"""
|
||||
try:
|
||||
await schedule_user_cache_warmup_task()
|
||||
except Exception as e:
|
||||
logger.error(f"User cache warmup failed: {e}")
|
||||
|
||||
async def _preload_user_cache(self):
|
||||
"""用户缓存预加载"""
|
||||
try:
|
||||
await schedule_user_cache_preload_task()
|
||||
except Exception as e:
|
||||
logger.error(f"User cache preload failed: {e}")
|
||||
|
||||
async def _cleanup_user_cache(self):
|
||||
"""用户缓存清理"""
|
||||
try:
|
||||
await schedule_user_cache_cleanup_task()
|
||||
except Exception as e:
|
||||
logger.error(f"User cache cleanup failed: {e}")
|
||||
|
||||
|
||||
# Beatmap缓存调度器(保持向后兼容)
|
||||
class BeatmapsetCacheScheduler(CacheScheduler):
|
||||
|
||||
120
app/scheduler/user_cache_scheduler.py
Normal file
120
app/scheduler/user_cache_scheduler.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
用户缓存预热任务调度器
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.log import logger
|
||||
from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from sqlmodel import col, func, select
|
||||
|
||||
|
||||
async def schedule_user_cache_preload_task():
|
||||
"""定时用户缓存预加载任务"""
|
||||
# 默认启用用户缓存预加载,除非明确禁用
|
||||
enable_user_cache_preload = getattr(settings, "enable_user_cache_preload", True)
|
||||
if not enable_user_cache_preload:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Starting user cache preload task...")
|
||||
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 使用独立的数据库会话
|
||||
from app.dependencies.database import with_db
|
||||
async with with_db() as session:
|
||||
# 获取最近24小时内活跃的用户(提交过成绩的用户)
|
||||
recent_time = datetime.now(UTC) - timedelta(hours=24)
|
||||
|
||||
active_user_ids = (
|
||||
await session.exec(
|
||||
select(Score.user_id, func.count().label("score_count"))
|
||||
.where(col(Score.ended_at) >= recent_time)
|
||||
.group_by(col(Score.user_id))
|
||||
.order_by(col("score_count").desc())
|
||||
.limit(settings.user_cache_max_preload_users) # 使用配置中的限制
|
||||
)
|
||||
).all()
|
||||
|
||||
if active_user_ids:
|
||||
user_ids = [row[0] for row in active_user_ids]
|
||||
await cache_service.preload_user_cache(session, user_ids)
|
||||
logger.info(f"Preloaded cache for {len(user_ids)} active users")
|
||||
else:
|
||||
logger.info("No active users found for cache preload")
|
||||
|
||||
logger.info("User cache preload task completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"User cache preload task failed: {e}")
|
||||
|
||||
|
||||
async def schedule_user_cache_warmup_task():
|
||||
"""定时用户缓存预热任务 - 预加载排行榜前100用户"""
|
||||
try:
|
||||
logger.info("Starting user cache warmup task...")
|
||||
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 使用独立的数据库会话
|
||||
from app.dependencies.database import with_db
|
||||
async with with_db() as session:
|
||||
# 获取全球排行榜前100的用户
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.models.score import GameMode
|
||||
|
||||
for mode in GameMode:
|
||||
try:
|
||||
top_users = (
|
||||
await session.exec(
|
||||
select(UserStatistics.user_id)
|
||||
.where(UserStatistics.mode == mode)
|
||||
.order_by(col(UserStatistics.pp).desc())
|
||||
.limit(100)
|
||||
)
|
||||
).all()
|
||||
|
||||
if top_users:
|
||||
user_ids = list(top_users)
|
||||
await cache_service.preload_user_cache(session, user_ids)
|
||||
logger.info(f"Warmed cache for top 100 users in {mode}")
|
||||
|
||||
# 避免过载,稍微延迟
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to warm cache for {mode}: {e}")
|
||||
continue
|
||||
|
||||
logger.info("User cache warmup task completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"User cache warmup task failed: {e}")
|
||||
|
||||
|
||||
async def schedule_user_cache_cleanup_task():
|
||||
"""定时用户缓存清理任务"""
|
||||
try:
|
||||
logger.info("Starting user cache cleanup task...")
|
||||
|
||||
redis = get_redis()
|
||||
|
||||
# 清理过期的用户缓存(Redis会自动处理TTL,这里主要记录统计信息)
|
||||
cache_service = get_user_cache_service(redis)
|
||||
stats = await cache_service.get_cache_stats()
|
||||
|
||||
logger.info(f"User cache stats: {stats}")
|
||||
logger.info("User cache cleanup task completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"User cache cleanup task failed: {e}")
|
||||
@@ -421,7 +421,7 @@ class RankingCacheService:
|
||||
select(User.country_code, func.count().label("user_count"))
|
||||
.where(col(User.is_active).is_(True))
|
||||
.group_by(User.country_code)
|
||||
.order_by(col("user_count").desc())
|
||||
.order_by(func.count().desc())
|
||||
.limit(settings.ranking_cache_top_countries)
|
||||
)
|
||||
).all()
|
||||
|
||||
388
app/service/user_cache_service.py
Normal file
388
app/service/user_cache_service.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
用户缓存服务
|
||||
用于缓存用户信息,提供热缓存和实时刷新功能
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import User, UserResp
|
||||
from app.database.lazer_user import SEARCH_INCLUDED
|
||||
from app.database.pp_best_score import PPBestScore
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.log import logger
|
||||
from app.models.score import GameMode
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""自定义 JSON 编码器,支持 datetime 序列化"""
|
||||
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def safe_json_dumps(data: Any) -> str:
|
||||
"""安全的 JSON 序列化,支持 datetime 对象"""
|
||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
|
||||
|
||||
|
||||
class UserCacheService:
|
||||
"""用户缓存服务"""
|
||||
|
||||
def __init__(self, redis: Redis):
|
||||
self.redis = redis
|
||||
self._refreshing = False
|
||||
self._background_tasks: set = set()
|
||||
|
||||
def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
|
||||
"""生成 V1 用户缓存键"""
|
||||
if ruleset:
|
||||
return f"v1_user:{user_id}:ruleset:{ruleset}"
|
||||
return f"v1_user:{user_id}"
|
||||
|
||||
async def get_v1_user_from_cache(
|
||||
self,
|
||||
user_id: int,
|
||||
ruleset: GameMode | None = None
|
||||
) -> dict | None:
|
||||
"""从缓存获取 V1 用户信息"""
|
||||
try:
|
||||
cache_key = self._get_v1_user_cache_key(user_id, ruleset)
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"V1 User cache hit for user {user_id}")
|
||||
return json.loads(cached_data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting V1 user from cache: {e}")
|
||||
return None
|
||||
|
||||
async def cache_v1_user(
|
||||
self,
|
||||
user_data: dict,
|
||||
user_id: int,
|
||||
ruleset: GameMode | None = None,
|
||||
expire_seconds: int | None = None
|
||||
):
|
||||
"""缓存 V1 用户信息"""
|
||||
try:
|
||||
if expire_seconds is None:
|
||||
expire_seconds = settings.user_cache_expire_seconds
|
||||
cache_key = self._get_v1_user_cache_key(user_id, ruleset)
|
||||
cached_data = safe_json_dumps(user_data)
|
||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||
logger.debug(f"Cached V1 user {user_id} for {expire_seconds}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching V1 user: {e}")
|
||||
|
||||
async def invalidate_v1_user_cache(self, user_id: int):
|
||||
"""使 V1 用户缓存失效"""
|
||||
try:
|
||||
# 删除 V1 用户信息缓存
|
||||
pattern = f"v1_user:{user_id}*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating V1 user cache: {e}")
|
||||
|
||||
def _get_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
|
||||
"""生成用户缓存键"""
|
||||
if ruleset:
|
||||
return f"user:{user_id}:ruleset:{ruleset}"
|
||||
return f"user:{user_id}"
|
||||
|
||||
def _get_user_scores_cache_key(
|
||||
self,
|
||||
user_id: int,
|
||||
score_type: str,
|
||||
mode: GameMode | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> str:
|
||||
"""生成用户成绩缓存键"""
|
||||
mode_part = f":{mode}" if mode else ""
|
||||
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
|
||||
|
||||
def _get_user_beatmapsets_cache_key(
|
||||
self,
|
||||
user_id: int,
|
||||
beatmapset_type: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> str:
|
||||
"""生成用户谱面集缓存键"""
|
||||
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
|
||||
|
||||
async def get_user_from_cache(
|
||||
self,
|
||||
user_id: int,
|
||||
ruleset: GameMode | None = None
|
||||
) -> UserResp | None:
|
||||
"""从缓存获取用户信息"""
|
||||
try:
|
||||
cache_key = self._get_user_cache_key(user_id, ruleset)
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"User cache hit for user {user_id}")
|
||||
data = json.loads(cached_data)
|
||||
return UserResp(**data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user from cache: {e}")
|
||||
return None
|
||||
|
||||
async def cache_user(
|
||||
self,
|
||||
user_resp: UserResp,
|
||||
ruleset: GameMode | None = None,
|
||||
expire_seconds: int | None = None
|
||||
):
|
||||
"""缓存用户信息"""
|
||||
try:
|
||||
if expire_seconds is None:
|
||||
expire_seconds = settings.user_cache_expire_seconds
|
||||
if user_resp.id is None:
|
||||
logger.warning("Cannot cache user with None id")
|
||||
return
|
||||
cache_key = self._get_user_cache_key(user_resp.id, ruleset)
|
||||
cached_data = user_resp.model_dump_json()
|
||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||
logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching user: {e}")
|
||||
|
||||
async def get_user_scores_from_cache(
|
||||
self,
|
||||
user_id: int,
|
||||
score_type: str,
|
||||
mode: GameMode | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> list[ScoreResp] | None:
|
||||
"""从缓存获取用户成绩"""
|
||||
try:
|
||||
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
|
||||
data = json.loads(cached_data)
|
||||
return [ScoreResp(**score_data) for score_data in data]
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user scores from cache: {e}")
|
||||
return None
|
||||
|
||||
async def cache_user_scores(
|
||||
self,
|
||||
user_id: int,
|
||||
score_type: str,
|
||||
scores: list[ScoreResp],
|
||||
mode: GameMode | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
expire_seconds: int | None = None
|
||||
):
|
||||
"""缓存用户成绩"""
|
||||
try:
|
||||
if expire_seconds is None:
|
||||
expire_seconds = settings.user_scores_cache_expire_seconds
|
||||
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
|
||||
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
|
||||
scores_json_list = [score.model_dump_json() for score in scores]
|
||||
cached_data = f"[{','.join(scores_json_list)}]"
|
||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching user scores: {e}")
|
||||
|
||||
async def get_user_beatmapsets_from_cache(
|
||||
self,
|
||||
user_id: int,
|
||||
beatmapset_type: str,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> list[Any] | None:
|
||||
"""从缓存获取用户谱面集"""
|
||||
try:
|
||||
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
|
||||
return json.loads(cached_data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user beatmapsets from cache: {e}")
|
||||
return None
|
||||
|
||||
async def cache_user_beatmapsets(
|
||||
self,
|
||||
user_id: int,
|
||||
beatmapset_type: str,
|
||||
beatmapsets: list[Any],
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
expire_seconds: int | None = None
|
||||
):
|
||||
"""缓存用户谱面集"""
|
||||
try:
|
||||
if expire_seconds is None:
|
||||
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
|
||||
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
|
||||
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
|
||||
serialized_beatmapsets = []
|
||||
for bms in beatmapsets:
|
||||
if hasattr(bms, 'model_dump_json'):
|
||||
serialized_beatmapsets.append(bms.model_dump_json())
|
||||
else:
|
||||
serialized_beatmapsets.append(safe_json_dumps(bms))
|
||||
cached_data = f"[{','.join(serialized_beatmapsets)}]"
|
||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||
logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching user beatmapsets: {e}")
|
||||
|
||||
async def invalidate_user_cache(self, user_id: int):
|
||||
"""使用户缓存失效"""
|
||||
try:
|
||||
# 删除用户信息缓存
|
||||
pattern = f"user:{user_id}*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
logger.info(f"Invalidated {len(keys)} cache entries for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating user cache: {e}")
|
||||
|
||||
async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
|
||||
"""使用户成绩缓存失效"""
|
||||
try:
|
||||
# 删除用户成绩相关缓存
|
||||
mode_pattern = f":{mode}" if mode else "*"
|
||||
pattern = f"user:{user_id}:scores:*{mode_pattern}*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating user scores cache: {e}")
|
||||
|
||||
async def preload_user_cache(self, session: AsyncSession, user_ids: list[int]):
|
||||
"""预加载用户缓存"""
|
||||
if self._refreshing:
|
||||
return
|
||||
|
||||
self._refreshing = True
|
||||
try:
|
||||
logger.info(f"Preloading cache for {len(user_ids)} users")
|
||||
|
||||
# 批量获取用户
|
||||
users = (
|
||||
await session.exec(
|
||||
select(User).where(col(User.id).in_(user_ids))
|
||||
)
|
||||
).all()
|
||||
|
||||
# 串行缓存用户信息,避免并发数据库访问问题
|
||||
cached_count = 0
|
||||
for user in users:
|
||||
if user.id != BANCHOBOT_ID:
|
||||
try:
|
||||
await self._cache_single_user(user, session)
|
||||
cached_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cache user {user.id}: {e}")
|
||||
|
||||
logger.info(f"Preloaded cache for {cached_count} users")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preloading user cache: {e}")
|
||||
finally:
|
||||
self._refreshing = False
|
||||
|
||||
async def _cache_single_user(self, user: User, session: AsyncSession):
|
||||
"""缓存单个用户"""
|
||||
try:
|
||||
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
|
||||
await self.cache_user(user_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching single user {user.id}: {e}")
|
||||
|
||||
async def refresh_user_cache_on_score_submit(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: int,
|
||||
mode: GameMode
|
||||
):
|
||||
"""成绩提交后刷新用户缓存"""
|
||||
try:
|
||||
# 使相关缓存失效(包括 v1 和 v2)
|
||||
await self.invalidate_user_cache(user_id)
|
||||
await self.invalidate_v1_user_cache(user_id)
|
||||
await self.invalidate_user_scores_cache(user_id, mode)
|
||||
|
||||
# 立即重新加载用户信息
|
||||
user = await session.get(User, user_id)
|
||||
if user and user.id != BANCHOBOT_ID:
|
||||
await self._cache_single_user(user, session)
|
||||
logger.info(f"Refreshed cache for user {user_id} after score submit")
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing user cache on score submit: {e}")
|
||||
|
||||
async def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
try:
|
||||
user_keys = await self.redis.keys("user:*")
|
||||
v1_user_keys = await self.redis.keys("v1_user:*")
|
||||
all_keys = user_keys + v1_user_keys
|
||||
total_size = 0
|
||||
|
||||
for key in all_keys[:100]: # 限制检查数量
|
||||
try:
|
||||
size = await self.redis.memory_usage(key)
|
||||
if size:
|
||||
total_size += size
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {
|
||||
"cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
|
||||
"cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
|
||||
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
|
||||
"cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
|
||||
"total_cached_entries": len(all_keys),
|
||||
"estimated_total_size_mb": (
|
||||
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||
),
|
||||
"refreshing": self._refreshing,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user cache stats: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# 全局缓存服务实例
|
||||
_user_cache_service: UserCacheService | None = None
|
||||
|
||||
|
||||
def get_user_cache_service(redis: Redis) -> UserCacheService:
|
||||
"""获取用户缓存服务实例"""
|
||||
global _user_cache_service
|
||||
if _user_cache_service is None:
|
||||
_user_cache_service = UserCacheService(redis)
|
||||
return _user_cache_service
|
||||
Reference in New Issue
Block a user