diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index 6a5ca2b..d1b3451 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -315,3 +315,5 @@ class BeatmapsetResp(BeatmapsetBase):
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[BeatmapsetResp]
total: int
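+    # `cursor` mirrors the raw pagination object from the osu! API response;
+    # `cursor_string` carries the same cursor base64-encoded by the fetcher.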
+ cursor: dict[str, int | float] | None = None
+ cursor_string: str | None = None
diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py
index 181eb05..dd4ebda 100644
--- a/app/fetcher/beatmapset.py
+++ b/app/fetcher/beatmapset.py
@@ -1,14 +1,83 @@
from __future__ import annotations
+import asyncio
+import base64
+import hashlib
+import json
+
from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
+from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import logger
from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
from ._base import BaseFetcher
+import redis.asyncio as redis
+
class BeatmapsetFetcher(BaseFetcher):
+ @staticmethod
+ def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
+ """获取主页预缓存查询列表"""
+ # 主页常用查询组合
+ homepage_queries = []
+
+        # Primary sort orders
+ sorts = ["ranked_desc", "updated_desc", "favourites_desc", "plays_desc"]
+
+ for sort in sorts:
+            # First page (empty cursor)
+ query = SearchQueryModel(
+ q="",
+ s="leaderboard",
+ sort=sort, # type: ignore
+ nsfw=False,
+ m=0
+ )
+ homepage_queries.append((query, {}))
+
+ return homepage_queries
+
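+    # Every outgoing osu! API call funnels through this override, so the
+    # shared osu_api_rate_limiter also throttles prefetching and cache warmup.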
+ async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
+ """覆盖基类方法,添加速率限制"""
+ # 在请求前获取速率限制许可
+ await osu_api_rate_limiter.acquire()
+
+        # Delegate to the base class request method
+ return await super().request_api(url, method, **kwargs)
+
+ @staticmethod
+ def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
+ """生成搜索缓存键"""
+ # 创建包含查询参数和 cursor 的字典
+ cache_data = {
+ **query.model_dump(
+ exclude_none=True, exclude_unset=True, exclude_defaults=True
+ ),
+ "cursor": cursor
+ }
+
+        # Serialize to JSON and derive an MD5 hash as a compact, stable key
+ cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
+ cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
+
+ return f"beatmapset:search:{cache_hash}"
+
+ @staticmethod
+ def _encode_cursor(cursor_dict: dict[str, int | float]) -> str:
+ """将cursor字典编码为base64字符串"""
+ cursor_json = json.dumps(cursor_dict, separators=(",", ":"))
+ return base64.b64encode(cursor_json.encode()).decode()
+
+ @staticmethod
+ def _decode_cursor(cursor_string: str) -> dict[str, int | float]:
+ """将base64字符串解码为cursor字典"""
+ try:
+ cursor_json = base64.b64decode(cursor_string).decode()
+ return json.loads(cursor_json)
+ except Exception:
+ return {}
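+
+    # Illustrative round-trip (values are made up, not real API data):
+    #     s = BeatmapsetFetcher._encode_cursor({"approved_date": 1672531200, "id": 123})
+    #     BeatmapsetFetcher._decode_cursor(s) == {"approved_date": 1672531200, "id": 123}
+    #     BeatmapsetFetcher._decode_cursor("not-base64") == {}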
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug(
f"[BeatmapsetFetcher] get_beatmapset: {beatmap_set_id}"
@@ -21,21 +90,236 @@ class BeatmapsetFetcher(BaseFetcher):
)
async def search_beatmapset(
- self, query: SearchQueryModel, cursor: Cursor
+ self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
) -> SearchBeatmapsetsResp:
logger.opt(colors=True).debug(
f"[BeatmapsetFetcher] search_beatmapset: {query}"
)
+        # Generate the cache key
+ cache_key = self._generate_cache_key(query, cursor)
+
+        # Try to serve the result from cache first
+ cached_result = await redis_client.get(cache_key)
+ if cached_result:
+ logger.opt(colors=True).debug(
+ f"[BeatmapsetFetcher] Cache hit for key: {cache_key}"
+ )
+ try:
+ cached_data = json.loads(cached_result)
+ return SearchBeatmapsetsResp.model_validate(cached_data)
+ except Exception as e:
+ logger.opt(colors=True).warning(
+ f"[BeatmapsetFetcher] Cache data invalid, fetching from API: {e}"
+ )
+
+        # Cache miss: fetch the data from the API
+ logger.opt(colors=True).debug(
+ "[BeatmapsetFetcher] Cache miss, fetching from API"
+ )
+
params = query.model_dump(
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
- for k, v in cursor.items():
- params[f"cursor[{k}]"] = v
- resp = SearchBeatmapsetsResp.model_validate(
- await self.request_api(
- "https://osu.ppy.sh/api/v2/beatmapsets/search",
- params=params,
- )
+
+ if query.cursor_string:
+ params["cursor_string"] = query.cursor_string
+ else:
+ for k, v in cursor.items():
+ params[f"cursor[{k}]"] = v
+
+ api_response = await self.request_api(
+ "https://osu.ppy.sh/api/v2/beatmapsets/search",
+ params=params,
)
+
+        # Derive cursor_string from the cursor in the response
+ if api_response.get("cursor"):
+ cursor_dict = api_response["cursor"]
+ api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+
+        # Cache the result for 15 minutes
+        cache_ttl = 15 * 60  # 15 minutes
+ await redis_client.set(
+ cache_key,
+ json.dumps(api_response, separators=(",", ":")),
+ ex=cache_ttl
+ )
+
+ logger.opt(colors=True).debug(
+ f"[BeatmapsetFetcher] Cached result for key: "
+ f"{cache_key} (TTL: {cache_ttl}s)"
+ )
+
+ resp = SearchBeatmapsetsResp.model_validate(api_response)
+
+        # Prefetch selectively: only when the user is actively searching (a
+        # query string, a non-default sort, or an existing cursor), to avoid
+        # flooding the API during plain homepage browsing
+ if (api_response.get("cursor") and
+ (query.q or query.s != "leaderboard" or cursor)):
+            # Prefetch just the next page in the background (keeps volume
+            # low); delay the task instead of firing it immediately
+ async def delayed_prefetch():
+                await asyncio.sleep(3.0)  # wait 3 s before prefetching
+ await self.prefetch_next_pages(
+ query, api_response["cursor"], redis_client, pages=1
+ )
+
+            # Schedule the delayed prefetch task
+ task = asyncio.create_task(delayed_prefetch())
+            # Keep a strong reference so the task is not garbage-collected
+ if not hasattr(self, "_background_tasks"):
+ self._background_tasks = set()
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
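+            # The done callback drops the reference once the task finishes,
+            # so the set only ever holds in-flight prefetch tasks.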
+
return resp
+
+ async def prefetch_next_pages(
+ self, query: SearchQueryModel, current_cursor: Cursor,
+ redis_client: redis.Redis, pages: int = 3
+ ) -> None:
+ """预取下几页内容"""
+ if not current_cursor:
+ return
+
+ try:
+ cursor = current_cursor.copy()
+
+ for page in range(1, pages + 1):
+                # Request the next page using the current cursor
+ next_query = query.model_copy()
+
+ logger.opt(colors=True).debug(
+ f"[BeatmapsetFetcher] Prefetching page {page + 1}"
+ )
+
+                # Build the cache key for the next page
+ next_cache_key = self._generate_cache_key(next_query, cursor)
+
+                # Skip pages that are already cached
+ if await redis_client.exists(next_cache_key):
+ logger.opt(colors=True).debug(
+ f"[BeatmapsetFetcher] Page {page + 1} already cached"
+ )
+                    # Reuse the cached page's cursor so prefetching can continue
+ cached_data = await redis_client.get(next_cache_key)
+ if cached_data:
+ try:
+ data = json.loads(cached_data)
+ if data.get("cursor"):
+ cursor = data["cursor"]
+ continue
+ except Exception:
+ pass
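+                    # Cached page has no usable cursor: stop prefetching here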
+ break
+
+                # Pause between prefetched pages to avoid request bursts
+ if page > 1:
+                    await asyncio.sleep(1.5)  # 1.5 s between pages
+
+                # Fetch the next page
+ params = next_query.model_dump(
+ exclude_none=True, exclude_unset=True, exclude_defaults=True
+ )
+
+ for k, v in cursor.items():
+ params[f"cursor[{k}]"] = v
+
+ api_response = await self.request_api(
+ "https://osu.ppy.sh/api/v2/beatmapsets/search",
+ params=params,
+ )
+
+                # Encode the response cursor and advance to it
+ if api_response.get("cursor"):
+ cursor_dict = api_response["cursor"]
+ api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+                    cursor = cursor_dict  # advance the cursor to the next page
+                else:
+                    # No more pages available
+ break
+
+                # Cache the result (shorter TTL for prefetched pages)
+                prefetch_ttl = 10 * 60  # 10 minutes
+ await redis_client.set(
+ next_cache_key,
+ json.dumps(api_response, separators=(",", ":")),
+ ex=prefetch_ttl
+ )
+
+ logger.opt(colors=True).debug(
+ f"[BeatmapsetFetcher] Prefetched page {page + 1} "
+ f"(TTL: {prefetch_ttl}s)"
+ )
+
+ except Exception as e:
+ logger.opt(colors=True).warning(
+ f"[BeatmapsetFetcher] Prefetch failed: {e}"
+ )
+
+ async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
+ """预热主页缓存"""
+ homepage_queries = self._get_homepage_queries()
+
+ logger.opt(colors=True).info(
+ f"[BeatmapsetFetcher] Starting homepage cache warmup "
+ f"({len(homepage_queries)} queries)"
+ )
+
+ for i, (query, cursor) in enumerate(homepage_queries):
+ try:
+                # Pause between queries to avoid request bursts
+ if i > 0:
+                    await asyncio.sleep(2.0)  # 2 s between queries
+
+ cache_key = self._generate_cache_key(query, cursor)
+
+                # Skip queries that are already cached
+ if await redis_client.exists(cache_key):
+ logger.opt(colors=True).debug(
+ f"[BeatmapsetFetcher] "
+ f"Query {query.sort} already cached"
+ )
+ continue
+
+                # Fetch and cache
+ params = query.model_dump(
+ exclude_none=True, exclude_unset=True, exclude_defaults=True
+ )
+
+ api_response = await self.request_api(
+ "https://osu.ppy.sh/api/v2/beatmapsets/search",
+ params=params,
+ )
+
+                # Derive cursor_string from the response cursor
+ if api_response.get("cursor"):
+ cursor_dict = api_response["cursor"]
+ api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+
+                # Cache the result
+                cache_ttl = 20 * 60  # 20 minutes
+ await redis_client.set(
+ cache_key,
+ json.dumps(api_response, separators=(",", ":")),
+ ex=cache_ttl
+ )
+
+ logger.opt(colors=True).info(
+ f"[BeatmapsetFetcher] "
+ f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
+ )
+
+                # Prefetch the next two pages (also subject to the rate limit)
+ if api_response.get("cursor"):
+ await self.prefetch_next_pages(
+ query, api_response["cursor"], redis_client, pages=2
+ )
+
+ except Exception as e:
+ logger.opt(colors=True).error(
+ f"[BeatmapsetFetcher] "
+ f"Failed to warmup cache for {query.sort}: {e}"
+ )
diff --git a/app/helpers/rate_limiter.py b/app/helpers/rate_limiter.py
new file mode 100644
index 0000000..ad1f27b
--- /dev/null
+++ b/app/helpers/rate_limiter.py
@@ -0,0 +1,122 @@
+"""
+Rate limiter for osu! API requests to avoid abuse detection.
+根据 osu! API v2 的速率限制设计:
+- 默认:每分钟最多 1200 次请求
+- 突发:短时间内最多 200 次额外请求
+- 建议:每分钟不超过 60 次请求以避免滥用检测
+"""
+from __future__ import annotations
+
+import asyncio
+from collections import deque
+import time
+
+from app.log import logger
+
+
+class RateLimiter:
+ """osu! API 速率限制器"""
+
+ def __init__(
+ self,
+        max_requests_per_minute: int = 60,  # conservative limit
+        burst_limit: int = 10,  # burst limit within a short window
+        burst_window: float = 10.0,  # burst window in seconds
+ ):
+ self.max_requests_per_minute = max_requests_per_minute
+ self.burst_limit = burst_limit
+ self.burst_window = burst_window
+
+        # Track request timestamps
+ self.request_times: deque[float] = deque()
+ self.burst_times: deque[float] = deque()
+
+        # Lock to serialize access across concurrent coroutines
+ self._lock = asyncio.Lock()
+
+ async def acquire(self) -> None:
+ """获取请求许可,如果超过限制则等待"""
+ async with self._lock:
+ current_time = time.time()
+
+            # Drop expired request records
+ self._cleanup_old_requests(current_time)
+
+            # Work out whether we need to wait
+ wait_time = self._calculate_wait_time(current_time)
+
+ if wait_time > 0:
+ logger.opt(colors=True).info(
+ f"[RateLimiter] Rate limit reached, "
+ f"waiting {wait_time:.2f}s"
+ )
+ await asyncio.sleep(wait_time)
+ current_time = time.time()
+ self._cleanup_old_requests(current_time)
+
+            # Record the current request
+ self.request_times.append(current_time)
+ self.burst_times.append(current_time)
+
+ logger.opt(colors=True).debug(
+ f"[RateLimiter] Request granted. "
+ f"Recent requests: {len(self.request_times)}/min, "
+ f"{len(self.burst_times)}/{self.burst_window}s"
+ )
+
+ def _cleanup_old_requests(self, current_time: float) -> None:
+ """清理过期的请求记录"""
+ # 清理1分钟前的请求
+ minute_ago = current_time - 60.0
+ while self.request_times and self.request_times[0] < minute_ago:
+ self.request_times.popleft()
+
+        # Drop requests outside the burst window
+ burst_window_ago = current_time - self.burst_window
+ while self.burst_times and self.burst_times[0] < burst_window_ago:
+ self.burst_times.popleft()
+
+ def _calculate_wait_time(self, current_time: float) -> float:
+ """计算需要等待的时间"""
+ # 检查每分钟限制
+ if len(self.request_times) >= self.max_requests_per_minute:
+            # Wait until the oldest request ages past the one-minute window
+ oldest_request = self.request_times[0]
+ wait_for_minute_limit = oldest_request + 60.0 - current_time
+ else:
+ wait_for_minute_limit = 0.0
+
+        # Burst limit
+ if len(self.burst_times) >= self.burst_limit:
+            # Wait until the oldest burst request ages past the burst window
+ oldest_burst = self.burst_times[0]
+ wait_for_burst_limit = oldest_burst + self.burst_window - current_time
+ else:
+ wait_for_burst_limit = 0.0
+
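+        # Example: if the per-minute window is full and its oldest request
+        # was sent 20 s ago, the next permit is granted in roughly 40 s.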
+ return max(wait_for_minute_limit, wait_for_burst_limit, 0.0)
+
+ def get_status(self) -> dict[str, int | float]:
+ """获取当前速率限制状态"""
+ current_time = time.time()
+ self._cleanup_old_requests(current_time)
+
+ return {
+ "requests_this_minute": len(self.request_times),
+ "max_requests_per_minute": self.max_requests_per_minute,
+ "burst_requests": len(self.burst_times),
+ "burst_limit": self.burst_limit,
+ "next_reset_in_seconds": (
+ 60.0 - (current_time - self.request_times[0])
+ if self.request_times
+ else 0.0
+ ),
+ }
+
+
+# Global rate limiter instance
+osu_api_rate_limiter = RateLimiter(
+    max_requests_per_minute=50,  # conservative: below the recommended 60/min
+    burst_limit=8,  # at most 8 requests within the burst window
+    burst_window=10.0,  # 10-second window
+)
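+
+# Typical use, as in BeatmapsetFetcher.request_api:
+#     await osu_api_rate_limiter.acquire()
+#     ...  # issue the API request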
diff --git a/app/models/beatmap.py b/app/models/beatmap.py
index b639467..69304b6 100644
--- a/app/models/beatmap.py
+++ b/app/models/beatmap.py
@@ -202,3 +202,7 @@ class SearchQueryModel(BaseModel):
default=False,
description="不良内容",
)
+ cursor_string: str | None = Field(
+ default=None,
+ description="游标字符串,用于分页",
+ )
diff --git a/app/models/model.py b/app/models/model.py
index 34d4902..5b50da3 100644
--- a/app/models/model.py
+++ b/app/models/model.py
@@ -17,7 +17,7 @@ class UTCBaseModel(BaseModel):
return v
-Cursor = dict[str, int]
+Cursor = dict[str, int | float]
class RespWithCursor(BaseModel):
diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py
index d843208..52a4f1d 100644
--- a/app/router/v2/beatmapset.py
+++ b/app/router/v2/beatmapset.py
@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
-from app.dependencies.database import engine, get_db
+from app.dependencies.database import engine, get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
@@ -55,13 +55,37 @@ async def search_beatmapset(
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
+    redis=Depends(get_redis),
):
params = parse_qs(qs=request.url.query, keep_blank_values=True)
cursor = {}
+
+    # Parse cursor[field]-style query parameters
for k, v in params.items():
match = re.match(r"cursor\[(\w+)\]", k)
if match:
- cursor[match.group(1)] = v[0] if v else None
+            field_name = match.group(1)
+            field_value = v[0] if v else None
+            if field_value is not None:
+                # Coerce to the narrowest numeric type: int first, then
+                # float, falling back to the raw string
+                try:
+                    cursor[field_name] = int(field_value)
+                except ValueError:
+                    try:
+                        cursor[field_name] = float(field_value)
+                    except ValueError:
+                        cursor[field_name] = field_value
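+    # e.g. ?cursor[approved_date]=1672531200&cursor[id]=123 parses to
+    #      {"approved_date": 1672531200, "id": 123}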
+
if (
"recommended" in query.c
or len(query.r) > 0
@@ -73,7 +97,7 @@ async def search_beatmapset(
# TODO: search locally
return SearchBeatmapsetsResp(total=0, beatmapsets=[])
try:
- sets = await fetcher.search_beatmapset(query, cursor)
+ sets = await fetcher.search_beatmapset(query, cursor, redis)
background_tasks.add_task(_save_to_db, sets)
except HTTPError as e:
raise HTTPException(status_code=500, detail=str(e))
diff --git a/app/scheduler/__init__.py b/app/scheduler/__init__.py
new file mode 100644
index 0000000..c9d77d0
--- /dev/null
+++ b/app/scheduler/__init__.py
@@ -0,0 +1,6 @@
+"""缓存调度器模块"""
+from __future__ import annotations
+
+from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
+
+__all__ = ["start_cache_scheduler", "stop_cache_scheduler"]
diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py
new file mode 100644
index 0000000..254d315
--- /dev/null
+++ b/app/scheduler/cache_scheduler.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+
+from app.dependencies.database import get_redis
+from app.dependencies.fetcher import get_fetcher
+from app.log import logger
+
+
+class BeatmapsetCacheScheduler:
+ """谱面集缓存调度器"""
+
+    def __init__(self):
+        self.running = False
+        self.task: asyncio.Task | None = None
+
+ async def start(self):
+ """启动调度器"""
+ if self.running:
+ return
+
+ self.running = True
+ self.task = asyncio.create_task(self._run_scheduler())
+ logger.info("BeatmapsetCacheScheduler started")
+
+ async def stop(self):
+ """停止调度器"""
+ self.running = False
+ if self.task:
+ self.task.cancel()
+ try:
+ await self.task
+ except asyncio.CancelledError:
+ pass
+ logger.info("BeatmapsetCacheScheduler stopped")
+
+ async def _run_scheduler(self):
+ """运行调度器主循环"""
+ # 启动时立即执行一次预热
+ await self._warmup_cache()
+
+ while self.running:
+ try:
+                # Warm the cache once every 30 minutes
+                await asyncio.sleep(30 * 60)  # 30 minutes
+
+ if self.running:
+ await self._warmup_cache()
+
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Cache scheduler error: {e}")
+                await asyncio.sleep(60)  # wait a minute before retrying after an error
+
+ async def _warmup_cache(self):
+ """执行缓存预热"""
+ try:
+ logger.info("Starting cache warmup...")
+
+ fetcher = await get_fetcher()
+ redis = get_redis()
+
+            # Warm the homepage cache
+ await fetcher.warmup_homepage_cache(redis)
+
+ logger.info("Cache warmup completed successfully")
+
+ except Exception as e:
+ logger.error(f"Cache warmup failed: {e}")
+
+
+# Global scheduler instance
+cache_scheduler = BeatmapsetCacheScheduler()
+
+
+async def start_cache_scheduler():
+ """启动缓存调度器"""
+ await cache_scheduler.start()
+
+
+async def stop_cache_scheduler():
+ """停止缓存调度器"""
+ await cache_scheduler.stop()
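+
+
+# start_cache_scheduler()/stop_cache_scheduler() are wired into the FastAPI
+# lifespan hook in main.py.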
diff --git a/main.py b/main.py
index f8cc17f..90568de 100644
--- a/main.py
+++ b/main.py
@@ -21,6 +21,7 @@ from app.router import (
signalr_router,
)
from app.router.redirect import redirect_router
+from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
from app.service.beatmap_download_service import download_service
from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.create_banchobot import create_banchobot
@@ -74,9 +75,11 @@ async def lifespan(app: FastAPI):
await daily_challenge_job()
await create_banchobot()
await download_service.start_health_check() # 启动下载服务健康检查
+    await start_cache_scheduler()  # start the cache scheduler
# on shutdown
yield
stop_scheduler()
+    await stop_cache_scheduler()  # stop the cache scheduler
await download_service.stop_health_check() # 停止下载服务健康检查
await engine.dispose()
await redis_client.aclose()
@@ -111,6 +114,7 @@ app.include_router(fetcher_router)
app.include_router(file_router)
app.include_router(auth_router)
app.include_router(private_router)
+
# CORS 配置
origins = []
for url in [*settings.cors_urls, settings.server_url]: