diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index 6a5ca2b..d1b3451 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -315,3 +315,5 @@ class BeatmapsetResp(BeatmapsetBase):
 class SearchBeatmapsetsResp(SQLModel):
     beatmapsets: list[BeatmapsetResp]
     total: int
+    cursor: dict[str, int | float] | None = None
+    cursor_string: str | None = None
diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py
index 181eb05..dd4ebda 100644
--- a/app/fetcher/beatmapset.py
+++ b/app/fetcher/beatmapset.py
@@ -1,14 +1,83 @@
 from __future__ import annotations
 
+import asyncio
+import base64
+import hashlib
+import json
+
 from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
+from app.helpers.rate_limiter import osu_api_rate_limiter
 from app.log import logger
 from app.models.beatmap import SearchQueryModel
 from app.models.model import Cursor
 
 from ._base import BaseFetcher
 
+import redis.asyncio as redis
+
 
 class BeatmapsetFetcher(BaseFetcher):
+    @staticmethod
+    def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
+        """Build the list of homepage queries to pre-cache."""
+        # Query combinations commonly used on the homepage
+        homepage_queries = []
+
+        # Primary sort orders
+        sorts = ["ranked_desc", "updated_desc", "favourites_desc", "plays_desc"]
+
+        for sort in sorts:
+            # First page only
+            query = SearchQueryModel(
+                q="",
+                s="leaderboard",
+                sort=sort,  # type: ignore
+                nsfw=False,
+                m=0
+            )
+            homepage_queries.append((query, {}))
+
+        return homepage_queries
+
+    async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
+        """Override the base method to apply rate limiting."""
+        # Acquire a rate-limit permit before making the request
+        await osu_api_rate_limiter.acquire()
+
+        # Delegate to the base implementation
+        return await super().request_api(url, method, **kwargs)
+
+    @staticmethod
+    def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
+        """Build the cache key for a search."""
+        # Combine the query parameters and the cursor into a single dict
+        cache_data = {
+            **query.model_dump(
+                exclude_none=True, exclude_unset=True, exclude_defaults=True
+            ),
+            "cursor": cursor
+        }
+
+        # Serialize to JSON and hash with MD5
+        cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
+        cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
+
+        return f"beatmapset:search:{cache_hash}"
+
+    @staticmethod
+    def _encode_cursor(cursor_dict: dict[str, int | float]) -> str:
+        """Encode a cursor dict as a base64 string."""
+        cursor_json = json.dumps(cursor_dict, separators=(",", ":"))
+        return base64.b64encode(cursor_json.encode()).decode()
+
+    @staticmethod
+    def _decode_cursor(cursor_string: str) -> dict[str, int | float]:
+        """Decode a base64 string back into a cursor dict."""
+        try:
+            cursor_json = base64.b64decode(cursor_string).decode()
+            return json.loads(cursor_json)
+        except Exception:
+            return {}
+
     async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
         logger.opt(colors=True).debug(
             f"[BeatmapsetFetcher] get_beatmapset: {beatmap_set_id}"
         )
@@ -21,21 +90,236 @@ class BeatmapsetFetcher(BaseFetcher):
     async def search_beatmapset(
-        self, query: SearchQueryModel, cursor: Cursor
+        self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
     ) -> SearchBeatmapsetsResp:
         logger.opt(colors=True).debug(
             f"[BeatmapsetFetcher] search_beatmapset: {query}"
         )
+        # Build the cache key
+        cache_key = self._generate_cache_key(query, cursor)
+
+        # Try the cache first
+        cached_result = await redis_client.get(cache_key)
+        if cached_result:
+            logger.opt(colors=True).debug(
+                f"[BeatmapsetFetcher] Cache hit for key: {cache_key}"
+            )
+            try:
+                cached_data = json.loads(cached_result)
+                return SearchBeatmapsetsResp.model_validate(cached_data)
+            except Exception as e:
+                logger.opt(colors=True).warning(
+                    f"[BeatmapsetFetcher] Cache data invalid, fetching from API: {e}"
+                )
+
+        # Cache miss: fetch from the API
+        logger.opt(colors=True).debug(
+            "[BeatmapsetFetcher] Cache miss, fetching from API"
+        )
+
         params = query.model_dump(
             exclude_none=True, exclude_unset=True, exclude_defaults=True
         )
-        for k, v in cursor.items():
-            params[f"cursor[{k}]"] = v
-        resp = SearchBeatmapsetsResp.model_validate(
-            await self.request_api(
-                "https://osu.ppy.sh/api/v2/beatmapsets/search",
-                params=params,
-            )
+
+        if query.cursor_string:
+            params["cursor_string"] = query.cursor_string
+        else:
+            for k, v in cursor.items():
+                params[f"cursor[{k}]"] = v
+
+        api_response = await self.request_api(
+            "https://osu.ppy.sh/api/v2/beatmapsets/search",
+            params=params,
         )
+
+        # Propagate the cursor returned by the API
+        if api_response.get("cursor"):
+            cursor_dict = api_response["cursor"]
+            api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+
+        # Cache the result for 15 minutes
+        cache_ttl = 15 * 60  # 15 minutes
+        await redis_client.set(
+            cache_key,
+            json.dumps(api_response, separators=(",", ":")),
+            ex=cache_ttl
+        )
+
+        logger.opt(colors=True).debug(
+            f"[BeatmapsetFetcher] Cached result for key: "
+            f"{cache_key} (TTL: {cache_ttl}s)"
+        )
+
+        resp = SearchBeatmapsetsResp.model_validate(api_response)
+
+        # Prefetch selectively: only when the user is actively searching
+        # (a query string, a non-default sort, or an existing cursor),
+        # so plain homepage browsing does not trigger extra API requests.
+        if (api_response.get("cursor") and
+                (query.q or query.s != "leaderboard" or cursor)):
+            # Prefetch a single page in the background after a short delay
+            # (asyncio is already imported at module level)
+            async def delayed_prefetch():
+                await asyncio.sleep(3.0)  # 3-second delay
+                await self.prefetch_next_pages(
+                    query, api_response["cursor"], redis_client, pages=1
+                )
+
+            # Schedule the delayed prefetch task
+            task = asyncio.create_task(delayed_prefetch())
+            # Keep a reference so the task is not garbage-collected
+            if not hasattr(self, "_background_tasks"):
+                self._background_tasks = set()
+            self._background_tasks.add(task)
+            task.add_done_callback(self._background_tasks.discard)
         return resp
+
+    async def prefetch_next_pages(
+        self, query: SearchQueryModel, current_cursor: Cursor,
+        redis_client: redis.Redis, pages: int = 3
+    ) -> None:
+        """Prefetch the next few pages of results."""
+        if not current_cursor:
+            return
+
+        try:
+            cursor = current_cursor.copy()
+
+            for page in range(1, pages + 1):
+                # Request the next page with the current cursor
+                next_query = query.model_copy()
+
+                logger.opt(colors=True).debug(
+                    f"[BeatmapsetFetcher] Prefetching page {page + 1}"
+                )
+
+                # Build the cache key for the next page
+                next_cache_key = self._generate_cache_key(next_query, cursor)
+
+                # Skip pages that are already cached
+                if await redis_client.exists(next_cache_key):
+                    logger.opt(colors=True).debug(
+                        f"[BeatmapsetFetcher] Page {page + 1} already cached"
+                    )
+                    # Reuse the cached cursor to keep prefetching
+                    cached_data = await redis_client.get(next_cache_key)
+                    if cached_data:
+                        try:
+                            data = json.loads(cached_data)
+                            if data.get("cursor"):
+                                cursor = data["cursor"]
+                                continue
+                        except Exception:
+                            pass
+                    break
+
+                # Pause between prefetched pages to avoid request bursts
+                if page > 1:
+                    await asyncio.sleep(1.5)  # 1.5-second delay
+
+                # Fetch the next page
+                params = next_query.model_dump(
+                    exclude_none=True, exclude_unset=True, exclude_defaults=True
+                )
+
+                for k, v in cursor.items():
+                    params[f"cursor[{k}]"] = v
+
+                api_response = await self.request_api(
+                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
+                    params=params,
+                )
+
+                # Propagate the cursor returned by the API
+                if api_response.get("cursor"):
+                    cursor_dict = api_response["cursor"]
+                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+                    cursor = cursor_dict  # advance the cursor to the next page
+                else:
+                    # No more pages
+                    break
+
+                # Cache the result (shorter TTL for prefetched pages)
+                prefetch_ttl = 10 * 60  # 10 minutes
+                await redis_client.set(
+                    next_cache_key,
+                    json.dumps(api_response, separators=(",", ":")),
+                    ex=prefetch_ttl
+                )
+
+                logger.opt(colors=True).debug(
+                    f"[BeatmapsetFetcher] Prefetched page {page + 1} "
+                    f"(TTL: {prefetch_ttl}s)"
+                )
+
+        except Exception as e:
+            logger.opt(colors=True).warning(
+                f"[BeatmapsetFetcher] Prefetch failed: {e}"
+            )
+
+    async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
+        """Warm up the homepage cache."""
+        homepage_queries = self._get_homepage_queries()
+
+        logger.opt(colors=True).info(
+            f"[BeatmapsetFetcher] Starting homepage cache warmup "
+            f"({len(homepage_queries)} queries)"
+        )
+
+        for i, (query, cursor) in enumerate(homepage_queries):
+            try:
+                # Pause between requests to avoid request bursts
+                if i > 0:
+                    await asyncio.sleep(2.0)  # 2-second delay
+
+                cache_key = self._generate_cache_key(query, cursor)
+
+                # Skip queries that are already cached
+                if await redis_client.exists(cache_key):
+                    logger.opt(colors=True).debug(
+                        f"[BeatmapsetFetcher] "
+                        f"Query {query.sort} already cached"
+                    )
+                    continue
+
+                # Fetch and cache
+                params = query.model_dump(
+                    exclude_none=True, exclude_unset=True, exclude_defaults=True
+                )
+
+                api_response = await self.request_api(
+                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
+                    params=params,
+                )
+
+                # Propagate the cursor returned by the API
+                if api_response.get("cursor"):
+                    cursor_dict = api_response["cursor"]
+                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+
+                # Cache the result
+                cache_ttl = 20 * 60  # 20 minutes
+                await redis_client.set(
+                    cache_key,
+                    json.dumps(api_response, separators=(",", ":")),
+                    ex=cache_ttl
+                )
+
+                logger.opt(colors=True).info(
+                    f"[BeatmapsetFetcher] "
+                    f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
+                )
+
+                # Prefetch the first 2 pages (also rate-limited)
+                if api_response.get("cursor"):
+                    await self.prefetch_next_pages(
+                        query, api_response["cursor"], redis_client, pages=2
+                    )
+
+            except Exception as e:
+                logger.opt(colors=True).error(
+                    f"[BeatmapsetFetcher] "
+                    f"Failed to warm up cache for {query.sort}: {e}"
+                )
diff --git a/app/helpers/rate_limiter.py b/app/helpers/rate_limiter.py
new file mode 100644
index 0000000..ad1f27b
--- /dev/null
+++ b/app/helpers/rate_limiter.py
@@ -0,0 +1,122 @@
+"""
+Rate limiter for osu! API requests to avoid abuse detection.
+Designed around the osu! API v2 rate limits:
+- Default: at most 1200 requests per minute
+- Burst: at most 200 extra requests within a short window
+- Recommended: stay under 60 requests per minute to avoid abuse detection
+"""
+from __future__ import annotations
+
+import asyncio
+from collections import deque
+import time
+
+from app.log import logger
+
+
+class RateLimiter:
+    """Rate limiter for the osu! API."""
+
+    def __init__(
+        self,
+        max_requests_per_minute: int = 60,  # conservative limit
+        burst_limit: int = 10,  # burst limit within a short window
+        burst_window: float = 10.0,  # burst window (seconds)
+    ):
+        self.max_requests_per_minute = max_requests_per_minute
+        self.burst_limit = burst_limit
+        self.burst_window = burst_window
+
+        # Track request timestamps
+        self.request_times: deque[float] = deque()
+        self.burst_times: deque[float] = deque()
+
+        # Lock to keep the bookkeeping consistent under concurrency
+        self._lock = asyncio.Lock()
+
+    async def acquire(self) -> None:
+        """Acquire a request permit, waiting if the limit is exceeded."""
+        async with self._lock:
+            current_time = time.time()
+
+            # Drop expired request records
+            self._cleanup_old_requests(current_time)
+
+            # Work out whether we need to wait
+            wait_time = self._calculate_wait_time(current_time)
+
+            if wait_time > 0:
+                logger.opt(colors=True).info(
+                    f"[RateLimiter] Rate limit reached, "
+                    f"waiting {wait_time:.2f}s"
+                )
+                await asyncio.sleep(wait_time)
+                current_time = time.time()
+                self._cleanup_old_requests(current_time)
+
+            # Record the current request
+            self.request_times.append(current_time)
+            self.burst_times.append(current_time)
+
+            logger.opt(colors=True).debug(
+                f"[RateLimiter] Request granted. "
+                f"Recent requests: {len(self.request_times)}/min, "
+                f"{len(self.burst_times)}/{self.burst_window}s"
+            )
+
+    def _cleanup_old_requests(self, current_time: float) -> None:
+        """Drop request records that have aged out."""
+        # Drop requests older than one minute
+        minute_ago = current_time - 60.0
+        while self.request_times and self.request_times[0] < minute_ago:
+            self.request_times.popleft()
+
+        # Drop requests outside the burst window
+        burst_window_ago = current_time - self.burst_window
+        while self.burst_times and self.burst_times[0] < burst_window_ago:
+            self.burst_times.popleft()
+
+    def _calculate_wait_time(self, current_time: float) -> float:
+        """Work out how long to wait before the next request."""
+        # Check the per-minute limit
+        if len(self.request_times) >= self.max_requests_per_minute:
+            # Wait until the oldest request is more than a minute old
+            oldest_request = self.request_times[0]
+            wait_for_minute_limit = oldest_request + 60.0 - current_time
+        else:
+            wait_for_minute_limit = 0.0
+
+        # Check the burst limit
+        if len(self.burst_times) >= self.burst_limit:
+            # Wait until the oldest burst request leaves the burst window
+            oldest_burst = self.burst_times[0]
+            wait_for_burst_limit = oldest_burst + self.burst_window - current_time
+        else:
+            wait_for_burst_limit = 0.0
+
+        return max(wait_for_minute_limit, wait_for_burst_limit, 0.0)
+
+    def get_status(self) -> dict[str, int | float]:
+        """Return the current rate-limit status."""
+        current_time = time.time()
+        self._cleanup_old_requests(current_time)
+
+        return {
+            "requests_this_minute": len(self.request_times),
+            "max_requests_per_minute": self.max_requests_per_minute,
+            "burst_requests": len(self.burst_times),
+            "burst_limit": self.burst_limit,
+            "next_reset_in_seconds": (
+                60.0 - (current_time - self.request_times[0])
+                if self.request_times
+                else 0.0
+            ),
+        }
+
+
+# Global rate limiter instance
+osu_api_rate_limiter = RateLimiter(
+    max_requests_per_minute=50,  # conservative, below the recommended 60
+    burst_limit=8,  # at most 8 requests in a short burst
+    burst_window=10.0,  # 10-second window
+)
diff --git a/app/models/beatmap.py b/app/models/beatmap.py
index b639467..69304b6 100644
--- a/app/models/beatmap.py
+++ b/app/models/beatmap.py
@@ -202,3 +202,7 @@ class SearchQueryModel(BaseModel):
         default=False,
         description="不良内容",
     )
+    cursor_string: str | None = Field(
+        default=None,
+        description="Cursor string used for pagination",
+    )
diff --git a/app/models/model.py b/app/models/model.py
index 34d4902..5b50da3 100644
--- a/app/models/model.py
+++ b/app/models/model.py
@@ -17,7 +17,7 @@ class UTCBaseModel(BaseModel):
         return v
 
 
-Cursor = dict[str, int]
+Cursor = dict[str, int | float]
 
 
 class RespWithCursor(BaseModel):
diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py
index d843208..52a4f1d 100644
--- a/app/router/v2/beatmapset.py
+++ b/app/router/v2/beatmapset.py
@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
 from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
 from app.database.beatmapset import SearchBeatmapsetsResp
 from app.dependencies.beatmap_download import get_beatmap_download_service
-from app.dependencies.database import engine, get_db
+from app.dependencies.database import engine, get_db, get_redis
 from app.dependencies.fetcher import get_fetcher
 from app.dependencies.geoip import get_client_ip, get_geoip_helper
 from app.dependencies.user import get_client_user, get_current_user
@@ -55,13 +55,27 @@ async def search_beatmapset(
     current_user: User = Security(get_current_user, scopes=["public"]),
     db: AsyncSession = Depends(get_db),
     fetcher: Fetcher = Depends(get_fetcher),
+    redis=Depends(get_redis),
 ):
     params = parse_qs(qs=request.url.query, keep_blank_values=True)
     cursor = {}
+
+    # Parse cursor[field]-style query parameters
     for k, v in params.items():
         match = re.match(r"cursor\[(\w+)\]", k)
         if match:
-            cursor[match.group(1)] = v[0] if v else None
+            field_name = match.group(1)
+            field_value = v[0] if v else None
+            if field_value is not None:
+                # Coerce to int, then float, falling back to the raw string
+                try:
+                    cursor[field_name] = int(field_value)
+                except ValueError:
+                    try:
+                        cursor[field_name] = float(field_value)
+                    except ValueError:
+                        cursor[field_name] = field_value
+
     if (
         "recommended" in query.c
         or len(query.r) > 0
@@ -73,7 +87,7 @@
         # TODO: search locally
         return SearchBeatmapsetsResp(total=0, beatmapsets=[])
     try:
-        sets = await fetcher.search_beatmapset(query, cursor)
+        sets = await fetcher.search_beatmapset(query, cursor, redis)
         background_tasks.add_task(_save_to_db, sets)
     except HTTPError as e:
         raise HTTPException(status_code=500, detail=str(e))
diff --git a/app/scheduler/__init__.py b/app/scheduler/__init__.py
new file mode 100644
index 0000000..c9d77d0
--- /dev/null
+++ b/app/scheduler/__init__.py
@@ -0,0 +1,6 @@
+"""Cache scheduler module."""
+from __future__ import annotations
+
+from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
+
+__all__ = ["start_cache_scheduler", "stop_cache_scheduler"]
diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py
new file mode 100644
index 0000000..254d315
--- /dev/null
+++ b/app/scheduler/cache_scheduler.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+
+from app.dependencies.database import get_redis
+from app.dependencies.fetcher import get_fetcher
+from app.log import logger
+
+
+class BeatmapsetCacheScheduler:
+    """Scheduler that keeps the beatmapset cache warm."""
+
+    def __init__(self):
+        self.running = False
+        self.task: asyncio.Task | None = None
+
+    async def start(self):
+        """Start the scheduler."""
+        if self.running:
+            return
+
+        self.running = True
+        self.task = asyncio.create_task(self._run_scheduler())
+        logger.info("BeatmapsetCacheScheduler started")
+
+    async def stop(self):
+        """Stop the scheduler."""
+        self.running = False
+        if self.task:
+            self.task.cancel()
+            try:
+                await self.task
+            except asyncio.CancelledError:
+                pass
+        logger.info("BeatmapsetCacheScheduler stopped")
+
+    async def _run_scheduler(self):
+        """Main scheduler loop."""
+        # Run one warmup immediately on startup
+        await self._warmup_cache()
+
+        while self.running:
+            try:
+                # Warm up the cache every 30 minutes
+                await asyncio.sleep(30 * 60)  # 30 minutes
+
+                if self.running:
+                    await self._warmup_cache()
+
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"Cache scheduler error: {e}")
+                await asyncio.sleep(60)  # wait a minute after an error
+
+    async def _warmup_cache(self):
+        """Run a single cache warmup pass."""
+        try:
+            logger.info("Starting cache warmup...")
+
+            fetcher = await get_fetcher()
+            redis = get_redis()
+
+            # Warm up the homepage cache
+            await fetcher.warmup_homepage_cache(redis)
+
+            logger.info("Cache warmup completed successfully")
+
+        except Exception as e:
+            logger.error(f"Cache warmup failed: {e}")
+
+
+# Global scheduler instance
+cache_scheduler = BeatmapsetCacheScheduler()
+
+
+async def start_cache_scheduler():
+    """Start the cache scheduler."""
+    await cache_scheduler.start()
+
+
+async def stop_cache_scheduler():
+    """Stop the cache scheduler."""
+    await cache_scheduler.stop()
diff --git a/main.py b/main.py
index f8cc17f..90568de 100644
--- a/main.py
+++ b/main.py
@@ -21,6 +21,7 @@ from app.router import (
     signalr_router,
 )
 from app.router.redirect import redirect_router
+from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
 from app.service.beatmap_download_service import download_service
 from app.service.calculate_all_user_rank import calculate_user_rank
 from app.service.create_banchobot import create_banchobot
@@ -74,9 +75,11 @@ async def lifespan(app: FastAPI):
     await daily_challenge_job()
     await create_banchobot()
     await download_service.start_health_check()  # 启动下载服务健康检查
+    await start_cache_scheduler()  # start the cache scheduler
     # on shutdown
     yield
     stop_scheduler()
+    await stop_cache_scheduler()  # stop the cache scheduler
     await download_service.stop_health_check()  # 停止下载服务健康检查
     await engine.dispose()
     await redis_client.aclose()
@@ -111,6 +114,7 @@ app.include_router(fetcher_router)
 app.include_router(file_router)
 app.include_router(auth_router)
 app.include_router(private_router)
+
 # CORS 配置
 origins = []
 for url in [*settings.cors_urls, settings.server_url]:
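
Note on the cursor format: the cursor_string introduced above is nothing more than base64-encoded, compact JSON. Below is a minimal, self-contained sketch of the round-trip; the standalone helper names and the sample cursor values are made up for illustration, while in the patch itself this logic lives in BeatmapsetFetcher._encode_cursor and _decode_cursor:

    import base64
    import json

    def encode_cursor(cursor_dict: dict[str, int | float]) -> str:
        # Compact JSON, then base64: this is the token a response
        # exposes to clients as cursor_string
        cursor_json = json.dumps(cursor_dict, separators=(",", ":"))
        return base64.b64encode(cursor_json.encode()).decode()

    def decode_cursor(cursor_string: str) -> dict[str, int | float]:
        # Reverse of encode_cursor; an invalid token yields an empty dict,
        # matching the defensive behavior of _decode_cursor
        try:
            return json.loads(base64.b64decode(cursor_string).decode())
        except Exception:
            return {}

    # Round-trip: the dict the osu! API returns as "cursor" survives encoding
    # into cursor_string and decoding back into the cursor[...] params
    # (hypothetical example values).
    cursor = {"approved_date": 1716393600000, "id": 2048291}
    token = encode_cursor(cursor)
    assert decode_cursor(token) == cursor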