From 80d4237c5d81dde69dfa64ac45ce6f1c3592da02 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?=
<74496778+GooGuJiang@users.noreply.github.com>
Date: Fri, 22 Aug 2025 00:07:19 +0800
Subject: [PATCH] ruff fix
---
app/calculator.py | 16 +--
app/config.py | 4 +-
app/database/beatmapset.py | 15 ++-
app/database/field_utils.py | 20 ++--
app/database/lazer_user.py | 2 +-
app/database/score.py | 17 ++-
app/database/statistics.py | 2 +-
app/fetcher/_base.py | 28 +++--
app/fetcher/beatmapset.py | 24 ++--
app/helpers/rate_limiter.py | 1 +
app/router/v1/user.py | 27 +++--
app/router/v2/cache.py | 39 +++---
app/router/v2/ranking.py | 61 +++++-----
app/router/v2/score.py | 4 +-
app/router/v2/user.py | 60 +++++-----
app/scheduler/__init__.py | 1 +
app/scheduler/cache_scheduler.py | 26 ++--
app/scheduler/user_cache_scheduler.py | 32 ++---
app/service/beatmap_cache_service.py | 5 +-
app/service/ranking_cache_service.py | 163 +++++++++++++-------------
app/service/user_cache_service.py | 142 ++++++++++++----------
test_ranking_serialization.py | 90 ++++++++------
22 files changed, 423 insertions(+), 356 deletions(-)
diff --git a/app/calculator.py b/app/calculator.py
index cc340cf..344bafe 100644
--- a/app/calculator.py
+++ b/app/calculator.py
@@ -88,6 +88,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
# 使用线程池执行计算密集型操作以避免阻塞事件循环
import asyncio
+
loop = asyncio.get_event_loop()
def _calculate_pp_sync():
@@ -131,11 +132,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
async def pre_fetch_and_calculate_pp(
- score: "Score",
- beatmap_id: int,
- session: AsyncSession,
- redis,
- fetcher
+ score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
) -> float:
"""
优化版PP计算:预先获取beatmap文件并使用缓存
@@ -148,9 +145,7 @@ async def pre_fetch_and_calculate_pp(
if settings.suspicious_score_check:
beatmap_banned = (
await session.exec(
- select(exists()).where(
- col(BannedBeatmaps.beatmap_id) == beatmap_id
- )
+ select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
)
).first()
if beatmap_banned:
@@ -184,10 +179,7 @@ async def pre_fetch_and_calculate_pp(
async def batch_calculate_pp(
- scores_data: list[tuple["Score", int]],
- session: AsyncSession,
- redis,
- fetcher
+ scores_data: list[tuple["Score", int]], session: AsyncSession, redis, fetcher
) -> list[float]:
"""
批量计算PP:适用于重新计算或批量处理场景
diff --git a/app/config.py b/app/config.py
index 96ee2df..26d2bec 100644
--- a/app/config.py
+++ b/app/config.py
@@ -142,14 +142,14 @@ class Settings(BaseSettings):
beatmap_cache_expire_hours: int = 24
max_concurrent_pp_calculations: int = 10
enable_pp_calculation_threading: bool = True
-
+
# 排行榜缓存设置
enable_ranking_cache: bool = True
ranking_cache_expire_minutes: int = 10 # 排行榜缓存过期时间(分钟)
ranking_cache_refresh_interval_minutes: int = 10 # 排行榜缓存刷新间隔(分钟)
ranking_cache_max_pages: int = 20 # 最多缓存的页数
ranking_cache_top_countries: int = 20 # 缓存前N个国家的排行榜
-
+
# 用户缓存设置
enable_user_cache_preload: bool = True # 启用用户缓存预加载
user_cache_expire_seconds: int = 300 # 用户信息缓存过期时间(秒)
diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index 447da5b..48b0408 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -8,7 +8,7 @@ from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, field_validator, model_validator
-from sqlalchemy import Boolean, JSON, Column, DateTime, Text
+from sqlalchemy import JSON, Boolean, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -205,7 +205,18 @@ class BeatmapsetResp(BeatmapsetBase):
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
- @field_validator('nsfw', 'spotlight', 'video', 'can_be_hyped', 'discussion_locked', 'storyboard', 'discussion_enabled', 'is_scoreable', 'has_favourited', mode='before')
+ @field_validator(
+ "nsfw",
+ "spotlight",
+ "video",
+ "can_be_hyped",
+ "discussion_locked",
+ "storyboard",
+ "discussion_enabled",
+ "is_scoreable",
+ "has_favourited",
+ mode="before",
+ )
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
diff --git a/app/database/field_utils.py b/app/database/field_utils.py
index 53c1011..5f18134 100644
--- a/app/database/field_utils.py
+++ b/app/database/field_utils.py
@@ -2,13 +2,16 @@
数据库字段类型工具
提供处理数据库和 Pydantic 之间类型转换的工具
"""
-from typing import Any, Union
+
+from typing import Any
+
from pydantic import field_validator
from sqlalchemy import Boolean
def bool_field_validator(field_name: str):
"""为特定布尔字段创建验证器,处理数据库中的 0/1 整数"""
+
@field_validator(field_name, mode="before")
@classmethod
def validate_bool_field(cls, v: Any) -> bool:
@@ -16,20 +19,21 @@ def bool_field_validator(field_name: str):
if isinstance(v, int):
return bool(v)
return v
+
return validate_bool_field
def create_bool_field(**kwargs):
"""创建一个带有正确 SQLAlchemy 列定义的布尔字段"""
- from sqlmodel import Field, Column
-
+ from sqlmodel import Column, Field
+
# 如果没有指定 sa_column,则使用 Boolean 类型
- if 'sa_column' not in kwargs:
+ if "sa_column" not in kwargs:
# 处理 index 参数
- index = kwargs.pop('index', False)
+ index = kwargs.pop("index", False)
if index:
- kwargs['sa_column'] = Column(Boolean, index=True)
+ kwargs["sa_column"] = Column(Boolean, index=True)
else:
- kwargs['sa_column'] = Column(Boolean)
-
+ kwargs["sa_column"] = Column(Boolean)
+
return Field(**kwargs)
diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py
index 515f769..232b00d 100644
--- a/app/database/lazer_user.py
+++ b/app/database/lazer_user.py
@@ -136,7 +136,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_qat: bool = False
is_bng: bool = False
- @field_validator('playmode', mode='before')
+ @field_validator("playmode", mode="before")
@classmethod
def validate_playmode(cls, v):
"""将字符串转换为 GameMode 枚举"""
diff --git a/app/database/score.py b/app/database/score.py
index ddfef6d..88e2cfa 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -100,7 +100,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
sa_column=Column(JSON), default_factory=dict
)
- @field_validator('maximum_statistics', mode='before')
+ @field_validator("maximum_statistics", mode="before")
@classmethod
def validate_maximum_statistics(cls, v):
"""处理 maximum_statistics 字段中的字符串键,转换为 HitResult 枚举"""
@@ -151,7 +151,7 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
- @field_validator('gamemode', mode='before')
+ @field_validator("gamemode", mode="before")
@classmethod
def validate_gamemode(cls, v):
"""将字符串转换为 GameMode 枚举"""
@@ -209,7 +209,16 @@ class ScoreResp(ScoreBase):
ranked: bool = False
current_user_attributes: CurrentUserAttributes | None = None
- @field_validator('has_replay', 'passed', 'preserve', 'is_perfect_combo', 'legacy_perfect', 'processed', 'ranked', mode='before')
+ @field_validator(
+ "has_replay",
+ "passed",
+ "preserve",
+ "is_perfect_combo",
+ "legacy_perfect",
+ "processed",
+ "ranked",
+ mode="before",
+ )
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
@@ -217,7 +226,7 @@ class ScoreResp(ScoreBase):
return bool(v)
return v
- @field_validator('statistics', 'maximum_statistics', mode='before')
+ @field_validator("statistics", "maximum_statistics", mode="before")
@classmethod
def validate_statistics_fields(cls, v):
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
diff --git a/app/database/statistics.py b/app/database/statistics.py
index 04b037d..01226cb 100644
--- a/app/database/statistics.py
+++ b/app/database/statistics.py
@@ -44,7 +44,7 @@ class UserStatisticsBase(SQLModel):
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=True)
- @field_validator('mode', mode='before')
+ @field_validator("mode", mode="before")
@classmethod
def validate_mode(cls, v):
"""将字符串转换为 GameMode 枚举"""
diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py
index ec0bb4a..97d16db 100644
--- a/app/fetcher/_base.py
+++ b/app/fetcher/_base.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import time
-from typing import Optional
from app.dependencies.database import get_redis
from app.log import logger
@@ -11,6 +10,7 @@ from httpx import AsyncClient
class TokenAuthError(Exception):
"""Token 授权失败异常"""
+
pass
@@ -55,7 +55,7 @@ class BaseFetcher:
return await self._request_with_retry(url, method, **kwargs)
async def _request_with_retry(
- self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs
+ self, url: str, method: str = "GET", max_retries: int | None = None, **kwargs
) -> dict:
"""
带重试机制的请求方法
@@ -64,7 +64,7 @@ class BaseFetcher:
max_retries = self.max_retries
last_error = None
-
+
for attempt in range(max_retries + 1):
try:
# 检查 token 是否过期
@@ -126,7 +126,9 @@ class BaseFetcher:
)
continue
else:
- logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
+ logger.error(
+ f"Request failed after {max_retries + 1} attempts: {e}"
+ )
break
# 如果所有重试都失败了
@@ -194,9 +196,13 @@ class BaseFetcher:
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
- logger.info(f"Successfully refreshed access token for client {self.client_id}")
+ logger.info(
+ f"Successfully refreshed access token for client {self.client_id}"
+ )
except Exception as e:
- logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
+ logger.error(
+ f"Failed to refresh access token for client {self.client_id}: {e}"
+ )
# 清除无效的 token,要求重新授权
self.access_token = ""
self.refresh_token = ""
@@ -204,7 +210,9 @@ class BaseFetcher:
redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
- logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
+ logger.warning(
+ f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
+ )
raise
async def _trigger_reauthorization(self) -> None:
@@ -216,18 +224,18 @@ class BaseFetcher:
f"Authentication failed after {self._auth_retry_count} attempts. "
f"Triggering reauthorization for client {self.client_id}"
)
-
+
# 清除内存中的 token
self.access_token = ""
self.refresh_token = ""
self.token_expiry = 0
self._auth_retry_count = 0 # 重置重试计数器
-
+
# 清除 Redis 中的 token
redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
-
+
logger.warning(
f"All tokens cleared for client {self.client_id}. "
f"Please re-authorize using: {self.authorize_url}"
diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py
index f87f2ed..9f3c025 100644
--- a/app/fetcher/beatmapset.py
+++ b/app/fetcher/beatmapset.py
@@ -101,6 +101,7 @@ class BeatmapsetFetcher(BaseFetcher):
return json.loads(cursor_json)
except Exception:
return {}
+
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug(
f"[BeatmapsetFetcher] get_beatmapset: {beatmap_set_id}"
@@ -164,9 +165,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 将结果缓存 15 分钟
cache_ttl = 15 * 60 # 15 分钟
await redis_client.set(
- cache_key,
- json.dumps(api_response, separators=(",", ":")),
- ex=cache_ttl
+ cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
)
logger.opt(colors=True).debug(
@@ -178,10 +177,12 @@ class BeatmapsetFetcher(BaseFetcher):
# 智能预取:只在用户明确搜索时才预取,避免过多API请求
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
- if (api_response.get("cursor") and
- (query.q or query.s != "leaderboard" or cursor)):
+ if api_response.get("cursor") and (
+ query.q or query.s != "leaderboard" or cursor
+ ):
# 在后台预取下1页(减少预取量)
import asyncio
+
# 不立即创建任务,而是延迟一段时间再预取
async def delayed_prefetch():
await asyncio.sleep(3.0) # 延迟3秒
@@ -200,8 +201,11 @@ class BeatmapsetFetcher(BaseFetcher):
return resp
async def prefetch_next_pages(
- self, query: SearchQueryModel, current_cursor: Cursor,
- redis_client: redis.Redis, pages: int = 3
+ self,
+ query: SearchQueryModel,
+ current_cursor: Cursor,
+ redis_client: redis.Redis,
+ pages: int = 3,
) -> None:
"""预取下几页内容"""
if not current_cursor:
@@ -269,7 +273,7 @@ class BeatmapsetFetcher(BaseFetcher):
await redis_client.set(
next_cache_key,
json.dumps(api_response, separators=(",", ":")),
- ex=prefetch_ttl
+ ex=prefetch_ttl,
)
logger.opt(colors=True).debug(
@@ -317,7 +321,6 @@ class BeatmapsetFetcher(BaseFetcher):
params=params,
)
-
if api_response.get("cursor"):
cursor_dict = api_response["cursor"]
api_response["cursor_string"] = self._encode_cursor(cursor_dict)
@@ -327,7 +330,7 @@ class BeatmapsetFetcher(BaseFetcher):
await redis_client.set(
cache_key,
json.dumps(api_response, separators=(",", ":")),
- ex=cache_ttl
+ ex=cache_ttl,
)
logger.opt(colors=True).info(
@@ -335,7 +338,6 @@ class BeatmapsetFetcher(BaseFetcher):
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
)
-
if api_response.get("cursor"):
await self.prefetch_next_pages(
query, api_response["cursor"], redis_client, pages=2
diff --git a/app/helpers/rate_limiter.py b/app/helpers/rate_limiter.py
index ad1f27b..40efc9b 100644
--- a/app/helpers/rate_limiter.py
+++ b/app/helpers/rate_limiter.py
@@ -5,6 +5,7 @@ Rate limiter for osu! API requests to avoid abuse detection.
- 突发:短时间内最多 200 次额外请求
- 建议:每分钟不超过 60 次请求以避免滥用检测
"""
+
from __future__ import annotations
import asyncio
diff --git a/app/router/v1/user.py b/app/router/v1/user.py
index d8938ed..bb57c3a 100644
--- a/app/router/v1/user.py
+++ b/app/router/v1/user.py
@@ -4,7 +4,6 @@ import asyncio
from datetime import datetime
from typing import Literal
-from app.config import settings
from app.database.lazer_user import User
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies.database import Database, get_redis
@@ -56,7 +55,7 @@ class V1User(AllStrModel):
# 确保 user_id 不为 None
if db_user.id is None:
raise ValueError("User ID cannot be None")
-
+
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics:
@@ -118,26 +117,28 @@ async def get_user(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 确定查询方式和用户ID
is_id_query = type == "id" or user.isdigit()
-
+
# 解析 ruleset
ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None
-
+
# 如果是 ID 查询,先尝试从缓存获取
cached_v1_user = None
user_id_for_cache = None
-
+
if is_id_query:
try:
user_id_for_cache = int(user)
- cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
+ cached_v1_user = await cache_service.get_v1_user_from_cache(
+ user_id_for_cache, ruleset
+ )
if cached_v1_user:
return [V1User(**cached_v1_user)]
except (ValueError, TypeError):
pass # 不是有效的用户ID,继续数据库查询
-
+
# 从数据库查询用户
db_user = (
await session.exec(
@@ -146,23 +147,23 @@ async def get_user(
)
)
).first()
-
+
if not db_user:
return []
-
+
try:
# 生成用户数据
v1_user = await V1User.from_db(session, db_user, ruleset)
-
+
# 异步缓存结果(如果有用户ID)
if db_user.id is not None:
user_data = v1_user.model_dump()
asyncio.create_task(
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
)
-
+
return [v1_user]
-
+
except KeyError:
raise HTTPException(400, "Invalid request")
except ValueError as e:
diff --git a/app/router/v2/cache.py b/app/router/v2/cache.py
index 7dd970c..08a0b27 100644
--- a/app/router/v2/cache.py
+++ b/app/router/v2/cache.py
@@ -2,15 +2,15 @@
缓存管理和监控接口
提供缓存统计、清理和预热功能
"""
+
from __future__ import annotations
from app.dependencies.database import get_redis
-from app.dependencies.user import get_current_user
from app.service.user_cache_service import get_user_cache_service
from .router import router
-from fastapi import Depends, HTTPException, Security
+from fastapi import Depends, HTTPException
from pydantic import BaseModel
from redis.asyncio import Redis
@@ -34,7 +34,7 @@ async def get_cache_stats(
try:
cache_service = get_user_cache_service(redis)
user_cache_stats = await cache_service.get_cache_stats()
-
+
# 获取 Redis 基本信息
redis_info = await redis.info()
redis_stats = {
@@ -47,20 +47,17 @@ async def get_cache_stats(
"evicted_keys": redis_info.get("evicted_keys", 0),
"expired_keys": redis_info.get("expired_keys", 0),
}
-
+
# 计算缓存命中率
hits = redis_stats["keyspace_hits"]
misses = redis_stats["keyspace_misses"]
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
-
- return CacheStatsResponse(
- user_cache=user_cache_stats,
- redis_info=redis_stats
- )
-
+
+ return CacheStatsResponse(user_cache=user_cache_stats, redis_info=redis_stats)
+
except Exception as e:
- raise HTTPException(500, f"Failed to get cache stats: {str(e)}")
+ raise HTTPException(500, f"Failed to get cache stats: {e!s}")
@router.post(
@@ -80,7 +77,7 @@ async def invalidate_user_cache(
await cache_service.invalidate_v1_user_cache(user_id)
return {"message": f"Cache invalidated for user {user_id}"}
except Exception as e:
- raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
+ raise HTTPException(500, f"Failed to invalidate cache: {e!s}")
@router.post(
@@ -98,15 +95,15 @@ async def clear_all_user_cache(
user_keys = await redis.keys("user:*")
v1_user_keys = await redis.keys("v1_user:*")
all_keys = user_keys + v1_user_keys
-
+
if all_keys:
await redis.delete(*all_keys)
return {"message": f"Cleared {len(all_keys)} cache entries"}
else:
return {"message": "No cache entries found"}
-
+
except Exception as e:
- raise HTTPException(500, f"Failed to clear cache: {str(e)}")
+ raise HTTPException(500, f"Failed to clear cache: {e!s}")
class CacheWarmupRequest(BaseModel):
@@ -127,18 +124,22 @@ async def warmup_cache(
):
try:
cache_service = get_user_cache_service(redis)
-
+
if request.user_ids:
# 预热指定用户
from app.dependencies.database import with_db
+
async with with_db() as session:
await cache_service.preload_user_cache(session, request.user_ids)
return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
else:
# 预热活跃用户
- from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
+ from app.scheduler.user_cache_scheduler import (
+ schedule_user_cache_preload_task,
+ )
+
await schedule_user_cache_preload_task()
return {"message": f"Warmed up cache for top {request.limit} active users"}
-
+
except Exception as e:
- raise HTTPException(500, f"Failed to warmup cache: {str(e)}")
+ raise HTTPException(500, f"Failed to warmup cache: {e!s}")
diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py
index 3bc640f..dcb6d8f 100644
--- a/app/router/v2/ranking.py
+++ b/app/router/v2/ranking.py
@@ -45,24 +45,24 @@ async def get_country_ranking(
# 获取 Redis 连接和缓存服务
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
-
+
# 尝试从缓存获取数据
cached_data = await cache_service.get_cached_country_ranking(ruleset, page)
-
+
if cached_data:
# 从缓存返回数据
return CountryResponse(
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
-
+
# 缓存未命中,从数据库查询
response = CountryResponse(ranking=[])
countries = (await session.exec(select(User.country_code).distinct())).all()
-
+
for country in countries:
if not country: # 跳过空的国家代码
continue
-
+
statistics = (
await session.exec(
select(UserStatistics).where(
@@ -73,10 +73,10 @@ async def get_country_ranking(
)
)
).all()
-
+
if not statistics: # 跳过没有数据的国家
continue
-
+
pp = 0
country_stats = CountryStatistics(
code=country,
@@ -92,27 +92,28 @@ async def get_country_ranking(
pp += stat.pp
country_stats.performance = round(pp)
response.ranking.append(country_stats)
-
+
response.ranking.sort(key=lambda x: x.performance, reverse=True)
-
+
# 分页处理
page_size = 50
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
-
+
# 获取当前页的数据
current_page_data = response.ranking[start_idx:end_idx]
-
+
# 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
-
+
# 创建后台任务来缓存数据
import asyncio
+
asyncio.create_task(cache_task)
-
+
# 返回当前页的结果
response.ranking = current_page_data
return response
@@ -142,20 +143,16 @@ async def get_user_ranking(
# 获取 Redis 连接和缓存服务
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
-
+
# 尝试从缓存获取数据
- cached_data = await cache_service.get_cached_ranking(
- ruleset, type, country, page
- )
-
+ cached_data = await cache_service.get_cached_ranking(ruleset, type, country, page)
+
if cached_data:
# 从缓存返回数据
return TopUsersResponse(
- ranking=[
- UserStatisticsResp.model_validate(item) for item in cached_data
- ]
+ ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
)
-
+
# 缓存未命中,从数据库查询
wheres = [
col(UserStatistics.mode) == ruleset,
@@ -170,7 +167,7 @@ async def get_user_ranking(
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
-
+
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
@@ -178,7 +175,7 @@ async def get_user_ranking(
.limit(50)
.offset(50 * (page - 1))
)
-
+
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
@@ -186,18 +183,24 @@ async def get_user_ranking(
statistics, session, None, include
)
ranking_data.append(user_stats_resp)
-
+
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking(
- ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60
+ ruleset,
+ type,
+ cache_data,
+ country,
+ page,
+ ttl=settings.ranking_cache_expire_minutes * 60,
)
-
+
# 创建后台任务来缓存数据
import asyncio
+
asyncio.create_task(cache_task)
-
+
resp = TopUsersResponse(ranking=ranking_data)
return resp
@@ -328,4 +331,4 @@ async def get_ranking_cache_stats(
cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats()
- return stats """
\ No newline at end of file
+ return stats """
diff --git a/app/router/v2/score.py b/app/router/v2/score.py
index 648505a..d3eb903 100644
--- a/app/router/v2/score.py
+++ b/app/router/v2/score.py
@@ -183,7 +183,7 @@ async def submit_score(
}
db.add(rank_event)
await db.commit()
-
+
# 成绩提交后刷新用户缓存
try:
user_cache_service = get_user_cache_service(redis)
@@ -193,7 +193,7 @@ async def submit_score(
)
except Exception as e:
logger.error(f"Failed to refresh user cache after score submit: {e}")
-
+
background_task.add_task(process_user_achievement, resp.id)
return resp
diff --git a/app/router/v2/user.py b/app/router/v2/user.py
index 28ef086..9e67de6 100644
--- a/app/router/v2/user.py
+++ b/app/router/v2/user.py
@@ -57,25 +57,27 @@ async def get_users(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
if user_ids:
# 先尝试从缓存获取
cached_users = []
uncached_user_ids = []
-
+
for user_id in user_ids[:50]: # 限制50个
cached_user = await cache_service.get_user_from_cache(user_id)
if cached_user:
cached_users.append(cached_user)
else:
uncached_user_ids.append(user_id)
-
+
# 查询未缓存的用户
if uncached_user_ids:
searched_users = (
- await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
+ await session.exec(
+ select(User).where(col(User.id).in_(uncached_user_ids))
+ )
).all()
-
+
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
@@ -87,7 +89,7 @@ async def get_users(
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
asyncio.create_task(cache_service.cache_user(user_resp))
-
+
return BatchUserResponse(users=cached_users)
else:
searched_users = (await session.exec(select(User).limit(50))).all()
@@ -102,7 +104,7 @@ async def get_users(
users.append(user_resp)
# 异步缓存
asyncio.create_task(cache_service.cache_user(user_resp))
-
+
return BatchUserResponse(users=users)
@@ -121,14 +123,14 @@ async def get_user_info_ruleset(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 如果是数字ID,先尝试从缓存获取
if user_id.isdigit():
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset)
if cached_user:
return cached_user
-
+
searched_user = (
await session.exec(
select(User).where(
@@ -140,17 +142,17 @@ async def get_user_info_ruleset(
).first()
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
-
+
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
ruleset=ruleset,
)
-
+
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
-
+
return user_resp
@@ -169,14 +171,14 @@ async def get_user_info(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 如果是数字ID,先尝试从缓存获取
if user_id.isdigit():
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
return cached_user
-
+
searched_user = (
await session.exec(
select(User).where(
@@ -188,16 +190,16 @@ async def get_user_info(
).first()
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
-
+
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
-
+
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp))
-
+
return user_resp
@@ -218,7 +220,7 @@ async def get_user_beatmapsets(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(
user_id, type.value, limit, offset
@@ -229,7 +231,7 @@ async def get_user_beatmapsets(
return [BeatmapPlaycountsResp(**item) for item in cached_result]
else:
return [BeatmapsetResp(**item) for item in cached_result]
-
+
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
@@ -275,10 +277,14 @@ async def get_user_beatmapsets(
# 异步缓存结果
async def cache_beatmapsets():
try:
- await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
+ await cache_service.cache_user_beatmapsets(
+ user_id, type.value, resp, limit, offset
+ )
except Exception as e:
- logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
-
+ logger.error(
+ f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
+ )
+
asyncio.create_task(cache_beatmapsets())
return resp
@@ -311,7 +317,7 @@ async def get_user_scores(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache(
@@ -319,7 +325,7 @@ async def get_user_scores(
)
if cached_scores is not None:
return cached_scores
-
+
db_user = await session.get(User, user_id)
if not db_user or db_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
@@ -355,7 +361,7 @@ async def get_user_scores(
).all()
if not scores:
return []
-
+
score_responses = [
await ScoreResp.from_db(
session,
@@ -363,14 +369,14 @@ async def get_user_scores(
)
for score in scores
]
-
+
# 异步缓存结果
asyncio.create_task(
cache_service.cache_user_scores(
user_id, type, score_responses, mode, limit, offset, cache_expire
)
)
-
+
return score_responses
diff --git a/app/scheduler/__init__.py b/app/scheduler/__init__.py
index c9d77d0..d6e4f7c 100644
--- a/app/scheduler/__init__.py
+++ b/app/scheduler/__init__.py
@@ -1,4 +1,5 @@
"""缓存调度器模块"""
+
from __future__ import annotations
from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py
index 72722e3..8edecfb 100644
--- a/app/scheduler/cache_scheduler.py
+++ b/app/scheduler/cache_scheduler.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from app.config import settings
-from app.dependencies.database import get_db, get_redis
+from app.dependencies.database import get_redis
from app.dependencies.fetcher import get_fetcher
from app.log import logger
from app.scheduler.user_cache_scheduler import (
@@ -44,10 +44,10 @@ class CacheScheduler:
"""运行调度器主循环"""
# 启动时立即执行一次预热
await self._warmup_cache()
-
+
# 启动时执行一次排行榜缓存刷新
await self._refresh_ranking_cache()
-
+
# 启动时执行一次用户缓存预热
await self._warmup_user_cache()
@@ -55,14 +55,16 @@ class CacheScheduler:
ranking_cache_counter = 0
user_cache_counter = 0
user_cleanup_counter = 0
-
+
# 从配置文件获取间隔设置
check_interval = 5 * 60 # 5分钟检查间隔
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
- ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
+ ranking_cache_interval = (
+ settings.ranking_cache_refresh_interval_minutes * 60
+ ) # 从配置读取
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔
-
+
beatmap_cache_cycles = beatmap_cache_interval // check_interval
ranking_cache_cycles = ranking_cache_interval // check_interval
user_cache_cycles = user_cache_interval // check_interval
@@ -90,12 +92,12 @@ class CacheScheduler:
if ranking_cache_counter >= ranking_cache_cycles:
await self._refresh_ranking_cache()
ranking_cache_counter = 0
-
+
# 用户缓存预加载
if user_cache_counter >= user_cache_cycles:
await self._preload_user_cache()
user_cache_counter = 0
-
+
# 用户缓存清理
if user_cleanup_counter >= user_cleanup_cycles:
await self._cleanup_user_cache()
@@ -129,15 +131,14 @@ class CacheScheduler:
logger.info("Starting ranking cache refresh...")
redis = get_redis()
-
+
# 导入排行榜缓存服务
+ # 使用独立的数据库会话
+ from app.dependencies.database import with_db
from app.service.ranking_cache_service import (
- get_ranking_cache_service,
schedule_ranking_refresh_task,
)
- # 使用独立的数据库会话
- from app.dependencies.database import with_db
async with with_db() as session:
await schedule_ranking_refresh_task(session, redis)
@@ -171,6 +172,7 @@ class CacheScheduler:
# Beatmap缓存调度器(保持向后兼容)
class BeatmapsetCacheScheduler(CacheScheduler):
"""谱面集缓存调度器 - 为了向后兼容"""
+
pass
diff --git a/app/scheduler/user_cache_scheduler.py b/app/scheduler/user_cache_scheduler.py
index cf43af2..b98df8a 100644
--- a/app/scheduler/user_cache_scheduler.py
+++ b/app/scheduler/user_cache_scheduler.py
@@ -1,15 +1,15 @@
"""
用户缓存预热任务调度器
"""
+
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from app.config import settings
-from app.database import User
from app.database.score import Score
-from app.dependencies.database import get_db, get_redis
+from app.dependencies.database import get_redis
from app.log import logger
from app.service.user_cache_service import get_user_cache_service
@@ -25,16 +25,17 @@ async def schedule_user_cache_preload_task():
try:
logger.info("Starting user cache preload task...")
-
+
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 使用独立的数据库会话
from app.dependencies.database import with_db
+
async with with_db() as session:
# 获取最近24小时内活跃的用户(提交过成绩的用户)
recent_time = datetime.now(UTC) - timedelta(hours=24)
-
+
active_user_ids = (
await session.exec(
select(Score.user_id, func.count().label("score_count"))
@@ -44,7 +45,7 @@ async def schedule_user_cache_preload_task():
.limit(settings.user_cache_max_preload_users) # 使用配置中的限制
)
).all()
-
+
if active_user_ids:
user_ids = [row[0] for row in active_user_ids]
await cache_service.preload_user_cache(session, user_ids)
@@ -62,17 +63,18 @@ async def schedule_user_cache_warmup_task():
"""定时用户缓存预热任务 - 预加载排行榜前100用户"""
try:
logger.info("Starting user cache warmup task...")
-
+
redis = get_redis()
cache_service = get_user_cache_service(redis)
-
+
# 使用独立的数据库会话
from app.dependencies.database import with_db
+
async with with_db() as session:
# 获取全球排行榜前100的用户
from app.database.statistics import UserStatistics
from app.models.score import GameMode
-
+
for mode in GameMode:
try:
top_users = (
@@ -83,15 +85,15 @@ async def schedule_user_cache_warmup_task():
.limit(100)
)
).all()
-
+
if top_users:
user_ids = list(top_users)
await cache_service.preload_user_cache(session, user_ids)
logger.info(f"Warmed cache for top 100 users in {mode}")
-
+
# 避免过载,稍微延迟
await asyncio.sleep(1)
-
+
except Exception as e:
logger.error(f"Failed to warm cache for {mode}: {e}")
continue
@@ -106,13 +108,13 @@ async def schedule_user_cache_cleanup_task():
"""定时用户缓存清理任务"""
try:
logger.info("Starting user cache cleanup task...")
-
+
redis = get_redis()
-
+
# 清理过期的用户缓存(Redis会自动处理TTL,这里主要记录统计信息)
cache_service = get_user_cache_service(redis)
stats = await cache_service.get_cache_stats()
-
+
logger.info(f"User cache stats: {stats}")
logger.info("User cache cleanup task completed successfully")
diff --git a/app/service/beatmap_cache_service.py b/app/service/beatmap_cache_service.py
index a4f6c17..2f76426 100644
--- a/app/service/beatmap_cache_service.py
+++ b/app/service/beatmap_cache_service.py
@@ -2,6 +2,7 @@
Beatmap缓存预取服务
用于提前缓存热门beatmap,减少成绩计算时的获取延迟
"""
+
from __future__ import annotations
import asyncio
@@ -155,9 +156,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
async def schedule_preload_task(
- session: AsyncSession,
- redis: Redis,
- fetcher: "Fetcher"
+ session: AsyncSession, redis: Redis, fetcher: "Fetcher"
):
"""
定时预加载任务
diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py
index 7eff17d..8006f8d 100644
--- a/app/service/ranking_cache_service.py
+++ b/app/service/ranking_cache_service.py
@@ -2,11 +2,12 @@
用户排行榜缓存服务
用于缓存用户排行榜数据,减轻数据库压力
"""
+
from __future__ import annotations
import asyncio
+from datetime import UTC, datetime
import json
-from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Literal
from app.config import settings
@@ -24,7 +25,7 @@ if TYPE_CHECKING:
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
-
+
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
@@ -33,7 +34,9 @@ class DateTimeEncoder(json.JSONEncoder):
def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
- return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
+ return json.dumps(
+ data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
+ )
class RankingCacheService:
@@ -84,7 +87,7 @@ class RankingCacheService:
try:
cache_key = self._get_cache_key(ruleset, type, country, page)
cached_data = await self.redis.get(cache_key)
-
+
if cached_data:
return json.loads(cached_data)
return None
@@ -107,11 +110,7 @@ class RankingCacheService:
# 使用配置文件的TTL设置
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60
- await self.redis.set(
- cache_key,
- safe_json_dumps(ranking_data),
- ex=ttl
- )
+ await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
logger.debug(f"Cached ranking data for {cache_key}")
except Exception as e:
logger.error(f"Error caching ranking: {e}")
@@ -126,7 +125,7 @@ class RankingCacheService:
try:
cache_key = self._get_stats_cache_key(ruleset, type, country)
cached_data = await self.redis.get(cache_key)
-
+
if cached_data:
return json.loads(cached_data)
return None
@@ -148,11 +147,7 @@ class RankingCacheService:
# 使用配置文件的TTL设置,统计信息缓存时间更长
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
- await self.redis.set(
- cache_key,
- safe_json_dumps(stats),
- ex=ttl
- )
+ await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
logger.debug(f"Cached stats for {cache_key}")
except Exception as e:
logger.error(f"Error caching stats: {e}")
@@ -166,7 +161,7 @@ class RankingCacheService:
try:
cache_key = self._get_country_cache_key(ruleset, page)
cached_data = await self.redis.get(cache_key)
-
+
if cached_data:
return json.loads(cached_data)
return None
@@ -186,11 +181,7 @@ class RankingCacheService:
cache_key = self._get_country_cache_key(ruleset, page)
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60
- await self.redis.set(
- cache_key,
- safe_json_dumps(ranking_data),
- ex=ttl
- )
+ await self.redis.set(cache_key, safe_json_dumps(ranking_data), ex=ttl)
logger.debug(f"Cached country ranking data for {cache_key}")
except Exception as e:
logger.error(f"Error caching country ranking: {e}")
@@ -200,7 +191,7 @@ class RankingCacheService:
try:
cache_key = self._get_country_stats_cache_key(ruleset)
cached_data = await self.redis.get(cache_key)
-
+
if cached_data:
return json.loads(cached_data)
return None
@@ -219,11 +210,7 @@ class RankingCacheService:
cache_key = self._get_country_stats_cache_key(ruleset)
if ttl is None:
ttl = settings.ranking_cache_expire_minutes * 60 * 6 # 6倍时间
- await self.redis.set(
- cache_key,
- safe_json_dumps(stats),
- ex=ttl
- )
+ await self.redis.set(cache_key, safe_json_dumps(stats), ex=ttl)
logger.debug(f"Cached country stats for {cache_key}")
except Exception as e:
logger.error(f"Error caching country stats: {e}")
@@ -238,7 +225,9 @@ class RankingCacheService:
) -> None:
"""刷新排行榜缓存"""
if self._refreshing:
- logger.info(f"Ranking cache refresh already in progress for {ruleset}:{type}")
+ logger.info(
+ f"Ranking cache refresh already in progress for {ruleset}:{type}"
+ )
return
# 使用配置文件的设置
@@ -248,7 +237,7 @@ class RankingCacheService:
self._refreshing = True
try:
logger.info(f"Starting ranking cache refresh for {ruleset}:{type}")
-
+
# 构建查询条件
wheres = [
col(UserStatistics.mode) == ruleset,
@@ -256,20 +245,22 @@ class RankingCacheService:
col(UserStatistics.is_ranked).is_(True),
]
include = ["user"]
-
+
if type == "performance":
order_by = col(UserStatistics.pp).desc()
include.append("rank_change_since_30_days")
else:
order_by = col(UserStatistics.ranked_score).desc()
-
+
if country:
- wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
+ wheres.append(
+ col(UserStatistics.user).has(country_code=country.upper())
+ )
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
total_users = len((await session.exec(total_users_query)).all())
-
+
# 计算统计信息
stats = {
"total_users": total_users,
@@ -278,7 +269,7 @@ class RankingCacheService:
"ruleset": ruleset,
"country": country,
}
-
+
# 缓存统计信息
await self.cache_stats(ruleset, type, stats, country)
@@ -292,11 +283,11 @@ class RankingCacheService:
.limit(50)
.offset(50 * (page - 1))
)
-
+
statistics_data = statistics_list.all()
if not statistics_data:
break # 没有更多数据
-
+
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
@@ -306,21 +297,19 @@ class RankingCacheService:
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
user_dict = json.loads(user_stats_resp.model_dump_json())
ranking_data.append(user_dict)
-
+
# 缓存这一页的数据
- await self.cache_ranking(
- ruleset, type, ranking_data, country, page
- )
-
+ await self.cache_ranking(ruleset, type, ranking_data, country, page)
+
# 添加延迟避免数据库过载
if page < max_pages:
await asyncio.sleep(0.1)
-
+
except Exception as e:
logger.error(f"Error caching page {page} for {ruleset}:{type}: {e}")
-
+
logger.info(f"Completed ranking cache refresh for {ruleset}:{type}")
-
+
except Exception as e:
logger.error(f"Ranking cache refresh failed for {ruleset}:{type}: {e}")
finally:
@@ -334,7 +323,9 @@ class RankingCacheService:
) -> None:
"""刷新地区排行榜缓存"""
if self._refreshing:
- logger.info(f"Country ranking cache refresh already in progress for {ruleset}")
+ logger.info(
+ f"Country ranking cache refresh already in progress for {ruleset}"
+ )
return
if max_pages is None:
@@ -343,17 +334,18 @@ class RankingCacheService:
self._refreshing = True
try:
logger.info(f"Starting country ranking cache refresh for {ruleset}")
-
+
# 获取所有国家
from app.database import User
+
countries = (await session.exec(select(User.country_code).distinct())).all()
-
+
# 计算每个国家的统计数据
country_stats_list = []
for country in countries:
if not country: # 跳过空的国家代码
continue
-
+
statistics = (
await session.exec(
select(UserStatistics).where(
@@ -364,10 +356,10 @@ class RankingCacheService:
)
)
).all()
-
+
if not statistics: # 跳过没有数据的国家
continue
-
+
pp = 0
country_stats = {
"code": country,
@@ -376,48 +368,48 @@ class RankingCacheService:
"ranked_score": 0,
"performance": 0,
}
-
+
for stat in statistics:
country_stats["active_users"] += 1
country_stats["play_count"] += stat.play_count
country_stats["ranked_score"] += stat.ranked_score
pp += stat.pp
-
+
country_stats["performance"] = round(pp)
country_stats_list.append(country_stats)
-
+
# 按表现分排序
country_stats_list.sort(key=lambda x: x["performance"], reverse=True)
-
+
# 计算统计信息
stats = {
"total_countries": len(country_stats_list),
"last_updated": datetime.now(UTC).isoformat(),
"ruleset": ruleset,
}
-
+
# 缓存统计信息
await self.cache_country_stats(ruleset, stats)
-
+
# 分页缓存数据(每页50个国家)
page_size = 50
for page in range(1, max_pages + 1):
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
-
+
page_data = country_stats_list[start_idx:end_idx]
if not page_data:
break # 没有更多数据
-
+
# 缓存这一页的数据
await self.cache_country_ranking(ruleset, page_data, page)
-
+
# 添加延迟避免Redis过载
if page < max_pages and page_data:
await asyncio.sleep(0.1)
-
+
logger.info(f"Completed country ranking cache refresh for {ruleset}")
-
+
except Exception as e:
logger.error(f"Country ranking cache refresh failed for {ruleset}: {e}")
finally:
@@ -427,11 +419,12 @@ class RankingCacheService:
"""刷新所有排行榜缓存"""
game_modes = [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
-
+
# 获取需要缓存的国家列表(活跃用户数量前20的国家)
from app.database import User
+
from sqlmodel import func
-
+
countries_query = (
await session.exec(
select(User.country_code, func.count().label("user_count"))
@@ -441,38 +434,40 @@ class RankingCacheService:
.limit(settings.ranking_cache_top_countries)
)
).all()
-
+
top_countries = [country for country, _ in countries_query]
-
+
refresh_tasks = []
-
+
# 全球排行榜
for mode in game_modes:
for ranking_type in ranking_types:
task = self.refresh_ranking_cache(session, mode, ranking_type)
refresh_tasks.append(task)
-
+
# 国家排行榜(仅前20个国家)
for country in top_countries:
for mode in game_modes:
for ranking_type in ranking_types:
- task = self.refresh_ranking_cache(session, mode, ranking_type, country)
+ task = self.refresh_ranking_cache(
+ session, mode, ranking_type, country
+ )
refresh_tasks.append(task)
-
+
# 地区排行榜
for mode in game_modes:
task = self.refresh_country_ranking_cache(session, mode)
refresh_tasks.append(task)
-
+
# 并发执行刷新任务,但限制并发数
semaphore = asyncio.Semaphore(5) # 最多同时5个任务
-
+
async def bounded_refresh(task):
async with semaphore:
await task
-
+
bounded_tasks = [bounded_refresh(task) for task in refresh_tasks]
-
+
try:
await asyncio.gather(*bounded_tasks, return_exceptions=True)
logger.info("All ranking cache refresh completed")
@@ -489,7 +484,7 @@ class RankingCacheService:
"""使缓存失效"""
try:
deleted_keys = 0
-
+
if ruleset and type:
# 删除特定的用户排行榜缓存
country_part = f":{country.upper()}" if country else ""
@@ -498,12 +493,14 @@ class RankingCacheService:
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
- logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
+ logger.info(
+ f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
+ )
elif ruleset:
# 删除特定游戏模式的所有缓存
patterns = [
f"ranking:{ruleset}:*",
- f"country_ranking:{ruleset}:*" if include_country_ranking else None
+ f"country_ranking:{ruleset}:*" if include_country_ranking else None,
]
for pattern in patterns:
if pattern:
@@ -516,15 +513,15 @@ class RankingCacheService:
patterns = ["ranking:*"]
if include_country_ranking:
patterns.append("country_ranking:*")
-
+
for pattern in patterns:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
-
+
logger.info(f"Invalidated all {deleted_keys} ranking cache keys")
-
+
except Exception as e:
logger.error(f"Error invalidating cache: {e}")
@@ -535,7 +532,7 @@ class RankingCacheService:
pattern = f"country_ranking:{ruleset}:*"
else:
pattern = "country_ranking:*"
-
+
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
@@ -550,10 +547,10 @@ class RankingCacheService:
ranking_keys = await self.redis.keys("ranking:*")
# 获取地区排行榜缓存
country_keys = await self.redis.keys("country_ranking:*")
-
+
total_keys = ranking_keys + country_keys
total_size = 0
-
+
for key in total_keys[:100]: # 限制检查数量以避免性能问题
try:
size = await self.redis.memory_usage(key)
@@ -561,7 +558,7 @@ class RankingCacheService:
total_size += size
except Exception:
continue
-
+
return {
"cached_user_rankings": len(ranking_keys),
"cached_country_rankings": len(country_keys),
diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py
index ed0926e..f999fe8 100644
--- a/app/service/user_cache_service.py
+++ b/app/service/user_cache_service.py
@@ -2,24 +2,23 @@
用户缓存服务
用于缓存用户信息,提供热缓存和实时刷新功能
"""
+
from __future__ import annotations
-import asyncio
+from datetime import datetime
import json
-from datetime import UTC, datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED
-from app.database.pp_best_score import PPBestScore
-from app.database.score import Score, ScoreResp
+from app.database.score import ScoreResp
from app.log import logger
from app.models.score import GameMode
from redis.asyncio import Redis
-from sqlmodel import col, exists, select
+from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
@@ -28,7 +27,7 @@ if TYPE_CHECKING:
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
-
+
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
@@ -48,16 +47,16 @@ class UserCacheService:
self._refreshing = False
self._background_tasks: set = set()
- def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
+ def _get_v1_user_cache_key(
+ self, user_id: int, ruleset: GameMode | None = None
+ ) -> str:
"""生成 V1 用户缓存键"""
if ruleset:
return f"v1_user:{user_id}:ruleset:{ruleset}"
return f"v1_user:{user_id}"
async def get_v1_user_from_cache(
- self,
- user_id: int,
- ruleset: GameMode | None = None
+ self, user_id: int, ruleset: GameMode | None = None
) -> dict | None:
"""从缓存获取 V1 用户信息"""
try:
@@ -72,11 +71,11 @@ class UserCacheService:
return None
async def cache_v1_user(
- self,
- user_data: dict,
+ self,
+ user_data: dict,
user_id: int,
ruleset: GameMode | None = None,
- expire_seconds: int | None = None
+ expire_seconds: int | None = None,
):
"""缓存 V1 用户信息"""
try:
@@ -97,7 +96,9 @@ class UserCacheService:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
- logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
+ logger.info(
+ f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
+ )
except Exception as e:
logger.error(f"Error invalidating V1 user cache: {e}")
@@ -108,31 +109,25 @@ class UserCacheService:
return f"user:{user_id}"
def _get_user_scores_cache_key(
- self,
- user_id: int,
- score_type: str,
+ self,
+ user_id: int,
+ score_type: str,
mode: GameMode | None = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
) -> str:
"""生成用户成绩缓存键"""
mode_part = f":{mode}" if mode else ""
return f"user:{user_id}:scores:{score_type}{mode_part}:limit:{limit}:offset:{offset}"
def _get_user_beatmapsets_cache_key(
- self,
- user_id: int,
- beatmapset_type: str,
- limit: int = 100,
- offset: int = 0
+ self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
) -> str:
"""生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache(
- self,
- user_id: int,
- ruleset: GameMode | None = None
+ self, user_id: int, ruleset: GameMode | None = None
) -> UserResp | None:
"""从缓存获取用户信息"""
try:
@@ -148,10 +143,10 @@ class UserCacheService:
return None
async def cache_user(
- self,
- user_resp: UserResp,
+ self,
+ user_resp: UserResp,
ruleset: GameMode | None = None,
- expire_seconds: int | None = None
+ expire_seconds: int | None = None,
):
"""缓存用户信息"""
try:
@@ -173,14 +168,18 @@ class UserCacheService:
score_type: str,
mode: GameMode | None = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
) -> list[ScoreResp] | None:
"""从缓存获取用户成绩"""
try:
- cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+ cache_key = self._get_user_scores_cache_key(
+ user_id, score_type, mode, limit, offset
+ )
cached_data = await self.redis.get(cache_key)
if cached_data:
- logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
+ logger.debug(
+ f"User scores cache hit for user {user_id}, type {score_type}"
+ )
data = json.loads(cached_data)
return [ScoreResp(**score_data) for score_data in data]
return None
@@ -196,34 +195,38 @@ class UserCacheService:
mode: GameMode | None = None,
limit: int = 100,
offset: int = 0,
- expire_seconds: int | None = None
+ expire_seconds: int | None = None,
):
"""缓存用户成绩"""
try:
if expire_seconds is None:
expire_seconds = settings.user_scores_cache_expire_seconds
- cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
+ cache_key = self._get_user_scores_cache_key(
+ user_id, score_type, mode, limit, offset
+ )
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
scores_json_list = [score.model_dump_json() for score in scores]
cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
- logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
+ logger.debug(
+ f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
+ )
except Exception as e:
logger.error(f"Error caching user scores: {e}")
async def get_user_beatmapsets_from_cache(
- self,
- user_id: int,
- beatmapset_type: str,
- limit: int = 100,
- offset: int = 0
+ self, user_id: int, beatmapset_type: str, limit: int = 100, offset: int = 0
) -> list[Any] | None:
"""从缓存获取用户谱面集"""
try:
- cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+ cache_key = self._get_user_beatmapsets_cache_key(
+ user_id, beatmapset_type, limit, offset
+ )
cached_data = await self.redis.get(cache_key)
if cached_data:
- logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
+ logger.debug(
+ f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
+ )
return json.loads(cached_data)
return None
except Exception as e:
@@ -237,23 +240,27 @@ class UserCacheService:
beatmapsets: list[Any],
limit: int = 100,
offset: int = 0,
- expire_seconds: int | None = None
+ expire_seconds: int | None = None,
):
"""缓存用户谱面集"""
try:
if expire_seconds is None:
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
- cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
+ cache_key = self._get_user_beatmapsets_cache_key(
+ user_id, beatmapset_type, limit, offset
+ )
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
serialized_beatmapsets = []
for bms in beatmapsets:
- if hasattr(bms, 'model_dump_json'):
+ if hasattr(bms, "model_dump_json"):
serialized_beatmapsets.append(bms.model_dump_json())
else:
serialized_beatmapsets.append(safe_json_dumps(bms))
cached_data = f"[{','.join(serialized_beatmapsets)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
- logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
+ logger.debug(
+ f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
+ )
except Exception as e:
logger.error(f"Error caching user beatmapsets: {e}")
@@ -269,7 +276,9 @@ class UserCacheService:
except Exception as e:
logger.error(f"Error invalidating user cache: {e}")
- async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
+ async def invalidate_user_scores_cache(
+ self, user_id: int, mode: GameMode | None = None
+ ):
"""使用户成绩缓存失效"""
try:
# 删除用户成绩相关缓存
@@ -278,7 +287,9 @@ class UserCacheService:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
- logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
+ logger.info(
+ f"Invalidated {len(keys)} score cache entries for user {user_id}"
+ )
except Exception as e:
logger.error(f"Error invalidating user scores cache: {e}")
@@ -290,12 +301,10 @@ class UserCacheService:
self._refreshing = True
try:
logger.info(f"Preloading cache for {len(user_ids)} users")
-
+
# 批量获取用户
users = (
- await session.exec(
- select(User).where(col(User.id).in_(user_ids))
- )
+ await session.exec(select(User).where(col(User.id).in_(user_ids)))
).all()
# 串行缓存用户信息,避免并发数据库访问问题
@@ -324,10 +333,7 @@ class UserCacheService:
logger.error(f"Error caching single user {user.id}: {e}")
async def refresh_user_cache_on_score_submit(
- self,
- session: AsyncSession,
- user_id: int,
- mode: GameMode
+ self, session: AsyncSession, user_id: int, mode: GameMode
):
"""成绩提交后刷新用户缓存"""
try:
@@ -335,7 +341,7 @@ class UserCacheService:
await self.invalidate_user_cache(user_id)
await self.invalidate_v1_user_cache(user_id)
await self.invalidate_user_scores_cache(user_id, mode)
-
+
# 立即重新加载用户信息
user = await session.get(User, user_id)
if user and user.id != BANCHOBOT_ID:
@@ -351,7 +357,7 @@ class UserCacheService:
v1_user_keys = await self.redis.keys("v1_user:*")
all_keys = user_keys + v1_user_keys
total_size = 0
-
+
for key in all_keys[:100]: # 限制检查数量
try:
size = await self.redis.memory_usage(key)
@@ -359,12 +365,22 @@ class UserCacheService:
total_size += size
except Exception:
continue
-
+
return {
- "cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
- "cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
+ "cached_users": len(
+ [
+ k
+ for k in user_keys
+ if ":scores:" not in k and ":beatmapsets:" not in k
+ ]
+ ),
+ "cached_v1_users": len(
+ [k for k in v1_user_keys if ":scores:" not in k]
+ ),
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
- "cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
+ "cached_user_beatmapsets": len(
+ [k for k in user_keys if ":beatmapsets:" in k]
+ ),
"total_cached_entries": len(all_keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
diff --git a/test_ranking_serialization.py b/test_ranking_serialization.py
index 91e767a..e9b81fd 100644
--- a/test_ranking_serialization.py
+++ b/test_ranking_serialization.py
@@ -1,92 +1,98 @@
#!/usr/bin/env python3
"""测试排行榜缓存序列化修复"""
-import asyncio
+from __future__ import annotations
+
+from datetime import UTC, datetime
import warnings
-from datetime import datetime, UTC
-from app.service.ranking_cache_service import DateTimeEncoder, safe_json_dumps
+
+from app.service.ranking_cache_service import safe_json_dumps
def test_datetime_serialization():
"""测试 datetime 序列化"""
print("🧪 测试 datetime 序列化...")
-
+
test_data = {
"id": 1,
"username": "test_user",
"last_updated": datetime.now(UTC),
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
- "stats": {
- "pp": 1000.0,
- "accuracy": 95.5,
- "last_played": datetime.now(UTC)
- }
+ "stats": {"pp": 1000.0, "accuracy": 95.5, "last_played": datetime.now(UTC)},
}
-
+
try:
# 测试自定义编码器
json_result = safe_json_dumps(test_data)
print("✅ datetime 序列化成功")
print(f" 序列化结果长度: {len(json_result)}")
-
+
# 验证可以重新解析
import json
+
parsed = json.loads(json_result)
assert "last_updated" in parsed
assert isinstance(parsed["last_updated"], str)
print("✅ 序列化的 JSON 可以正确解析")
-
+
except Exception as e:
print(f"❌ datetime 序列化测试失败: {e}")
import traceback
+
traceback.print_exc()
def test_boolean_serialization():
"""测试布尔值序列化"""
print("\n🧪 测试布尔值序列化...")
-
+
test_data = {
"user": {
- "is_active": 1, # 数据库中的整数布尔值
- "is_supporter": 0, # 数据库中的整数布尔值
- "has_profile": True, # 正常布尔值
+ "is_active": 1, # 数据库中的整数布尔值
+ "is_supporter": 0, # 数据库中的整数布尔值
+ "has_profile": True, # 正常布尔值
},
"stats": {
- "is_ranked": 1, # 数据库中的整数布尔值
- "verified": False, # 正常布尔值
- }
+ "is_ranked": 1, # 数据库中的整数布尔值
+ "verified": False, # 正常布尔值
+ },
}
-
+
try:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
json_result = safe_json_dumps(test_data)
-
+
# 检查是否有 Pydantic 序列化警告
- pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
+ pydantic_warnings = [
+ warning
+ for warning in w
+ if "PydanticSerializationUnexpectedValue" in str(warning.message)
+ ]
if pydantic_warnings:
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个布尔值序列化警告")
for warning in pydantic_warnings:
print(f" {warning.message}")
else:
print("✅ 布尔值序列化无警告")
-
+
# 验证序列化结果
import json
+
parsed = json.loads(json_result)
print(f"✅ 布尔值序列化成功,结果: {parsed}")
-
+
except Exception as e:
print(f"❌ 布尔值序列化测试失败: {e}")
import traceback
+
traceback.print_exc()
def test_complex_ranking_data():
"""测试复杂的排行榜数据序列化"""
print("\n🧪 测试复杂排行榜数据序列化...")
-
+
# 模拟排行榜数据结构
ranking_data = [
{
@@ -95,8 +101,8 @@ def test_complex_ranking_data():
"id": 1,
"username": "player1",
"country_code": "US",
- "is_active": 1, # 整数布尔值
- "is_supporter": 0, # 整数布尔值
+ "is_active": 1, # 整数布尔值
+ "is_supporter": 0, # 整数布尔值
"join_date": datetime(2020, 1, 1, tzinfo=UTC),
"last_visit": datetime.now(UTC),
},
@@ -104,9 +110,9 @@ def test_complex_ranking_data():
"pp": 8000.0,
"accuracy": 98.5,
"play_count": 5000,
- "is_ranked": 1, # 整数布尔值
+ "is_ranked": 1, # 整数布尔值
"last_updated": datetime.now(UTC),
- }
+ },
},
{
"id": 2,
@@ -125,41 +131,47 @@ def test_complex_ranking_data():
"play_count": 4500,
"is_ranked": 1,
"last_updated": datetime.now(UTC),
- }
- }
+ },
+ },
]
-
+
try:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
json_result = safe_json_dumps(ranking_data)
-
- pydantic_warnings = [warning for warning in w if 'PydanticSerializationUnexpectedValue' in str(warning.message)]
+
+ pydantic_warnings = [
+ warning
+ for warning in w
+ if "PydanticSerializationUnexpectedValue" in str(warning.message)
+ ]
if pydantic_warnings:
print(f"⚠️ 仍有 {len(pydantic_warnings)} 个序列化警告")
for warning in pydantic_warnings:
print(f" {warning.message}")
else:
print("✅ 复杂排行榜数据序列化无警告")
-
+
# 验证序列化结果
import json
+
parsed = json.loads(json_result)
assert len(parsed) == 2
assert parsed[0]["user"]["username"] == "player1"
print(f"✅ 复杂排行榜数据序列化成功,包含 {len(parsed)} 个条目")
-
+
except Exception as e:
print(f"❌ 复杂排行榜数据序列化测试失败: {e}")
import traceback
+
traceback.print_exc()
if __name__ == "__main__":
print("🚀 开始排行榜缓存序列化测试\n")
-
+
test_datetime_serialization()
- test_boolean_serialization()
+ test_boolean_serialization()
test_complex_ranking_data()
-
+
print("\n🎉 排行榜缓存序列化测试完成!")