From 819f52450c995575227b850ed1316809d9a5b10a Mon Sep 17 00:00:00 2001
From: MingxuanGame
Date: Sun, 23 Nov 2025 13:42:21 +0000
Subject: [PATCH] refactor(fetcher): implement passive rate limiter for API
 requests

---
 app/fetcher/_base.py        |  99 ++++++++++++++++++----
 app/fetcher/beatmapset.py   | 159 ++++++++++++------------------------
 app/helpers/rate_limiter.py | 116 --------------------------
 3 files changed, 135 insertions(+), 239 deletions(-)
 delete mode 100644 app/helpers/rate_limiter.py

diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py
index 1757cb9..b029755 100644
--- a/app/fetcher/_base.py
+++ b/app/fetcher/_base.py
@@ -1,10 +1,11 @@
 import asyncio
+from datetime import datetime
 import time
 
 from app.dependencies.database import get_redis
 from app.log import fetcher_logger
 
-from httpx import AsyncClient
+from httpx import AsyncClient, HTTPStatusError
 
 
 class TokenAuthError(Exception):
@@ -13,10 +14,63 @@ class TokenAuthError(Exception):
     pass
 
 
+class PassiveRateLimiter:
+    """
+    Passive rate limiter.
+    On a 429 response, read the Retry-After header and pause all requests.
+    """
+
+    def __init__(self):
+        self._lock = asyncio.Lock()
+        self._retry_after_time: float | None = None
+        self._waiting_tasks: set[asyncio.Task] = set()
+
+    async def wait_if_limited(self) -> None:
+        """If a rate limit is active, wait until it is lifted."""
+        async with self._lock:
+            if self._retry_after_time is not None:
+                current_time = time.time()
+                if current_time < self._retry_after_time:
+                    wait_seconds = self._retry_after_time - current_time
+                    logger.warning(f"Rate limited, waiting {wait_seconds:.2f} seconds")
+                    await asyncio.sleep(wait_seconds)
+                self._retry_after_time = None
+
+    async def handle_rate_limit(self, retry_after: str | int | None) -> None:
+        """
+        Handle a 429 response by recording when requests may resume.
+
+        Args:
+            retry_after: Value of the Retry-After header; either a number of seconds or an HTTP date.
+        """
+        async with self._lock:
+            if retry_after is None:
+                # No Retry-After header; fall back to waiting 60 seconds
+                wait_seconds = 60
+            elif isinstance(retry_after, int):
+                wait_seconds = retry_after
+            elif retry_after.isdigit():
+                wait_seconds = int(retry_after)
+            else:
+                # Try to parse an HTTP-date value
+                try:
+                    retry_time = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z")
+                    wait_seconds = max(0, (retry_time - datetime.utcnow()).total_seconds())
+                except ValueError:
+                    # Parsing failed; fall back to waiting 60 seconds
+                    wait_seconds = 60
+
+            self._retry_after_time = time.time() + wait_seconds
+            logger.warning(f"Rate limit triggered, will retry after {wait_seconds} seconds")
+
+
 logger = fetcher_logger("Fetcher")
 
 
 class BaseFetcher:
+    # Class-level rate limiter shared by all fetcher instances
+    _rate_limiter = PassiveRateLimiter()
+
     def __init__(
         self,
         client_id: str,
@@ -51,7 +105,7 @@ class BaseFetcher:
 
     async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
         """
-        Send an API request
+        Send an API request, with passive rate limiting
         """
         await self.ensure_valid_access_token()
 
@@ -59,24 +113,41 @@
         attempt = 0
 
         while attempt < 2:
+            # Wait out any active rate limit before sending the request
+            await self._rate_limiter.wait_if_limited()
+
             request_headers = {**headers, **self.header}
             request_kwargs = kwargs.copy()
 
             async with AsyncClient() as client:
-                response = await client.request(
-                    method,
-                    url,
-                    headers=request_headers,
-                    **request_kwargs,
-                )
+                try:
+                    response = await client.request(
+                        method,
+                        url,
+                        headers=request_headers,
+                        **request_kwargs,
+                    )
+                    response.raise_for_status()
+                    return response.json()
 
-                if response.status_code != 401:
-                    response.raise_for_status()
-                    return response.json()
+                except HTTPStatusError as e:
+                    # Handle 429 rate-limit responses
+                    if e.response.status_code == 429:
+                        retry_after = e.response.headers.get("Retry-After")
+                        logger.warning(f"Rate limited for {url}, Retry-After: {retry_after}")
+                        await self._rate_limiter.handle_rate_limit(retry_after)
+                        # Retry the current request after the limit lifts (attempt is not incremented)
+                        continue
 
-                attempt += 1
-                logger.warning(f"Received 401 error for {url}, attempt {attempt}")
-                await self._handle_unauthorized()
+                    # Handle 401 unauthorized responses
+                    if e.response.status_code == 401:
+                        attempt += 1
+                        logger.warning(f"Received 401 error for {url}, attempt {attempt}")
+                        await self._handle_unauthorized()
+                        continue
+
+                    # Re-raise any other HTTP error
+                    raise
 
         await self._clear_access_token()
         logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
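A note on the Retry-After fallback above: the strptime pattern only matches the IMF-fixdate form of an HTTP date and compares against a naive datetime.utcnow(). The standard library's email.utils.parsedate_to_datetime accepts the RFC 7231 HTTP-date formats directly, so a stricter variant of that branch could look like the sketch below. This is illustrative only and not part of the patch; the helper name parse_retry_after is an assumption.

# Sketch of stricter Retry-After parsing, stdlib only (assumed helper name).
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime


def parse_retry_after(value: str | None, default: float = 60.0) -> float:
    """Return how many seconds to wait for a Retry-After header value."""
    if value is None:
        return default
    if value.isdigit():
        # delta-seconds form, e.g. "120"
        return float(value)
    try:
        # HTTP-date form, e.g. "Sun, 06 Nov 1994 08:49:37 GMT"
        retry_time = parsedate_to_datetime(value)
    except (TypeError, ValueError):
        return default
    if retry_time.tzinfo is None:
        retry_time = retry_time.replace(tzinfo=timezone.utc)
    return max(0.0, (retry_time - datetime.now(timezone.utc)).total_seconds())

Either way the result feeds time.time() + wait_seconds, so the rest of the limiter is unchanged.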
{retry_after}") + await self._rate_limiter.handle_rate_limit(retry_after) + # 速率限制后重试当前请求(不增加 attempt) + continue - attempt += 1 - logger.warning(f"Received 401 error for {url}, attempt {attempt}") - await self._handle_unauthorized() + # 处理 401 未授权响应 + if e.response.status_code == 401: + attempt += 1 + logger.warning(f"Received 401 error for {url}, attempt {attempt}") + await self._handle_unauthorized() + continue + + # 其他 HTTP 错误直接抛出 + raise await self._clear_access_token() logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens") diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index 910163f..ce5167d 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -4,7 +4,6 @@ import hashlib import json from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp -from app.helpers.rate_limiter import osu_api_rate_limiter from app.log import fetcher_logger from app.models.beatmap import SearchQueryModel from app.models.model import Cursor @@ -12,21 +11,12 @@ from app.utils import bg_tasks from ._base import BaseFetcher -from httpx import AsyncClient from pydantic import TypeAdapter import redis.asyncio as redis - -class RateLimitError(Exception): - """速率限制异常""" - - pass - - logger = fetcher_logger("BeatmapsetFetcher") -MAX_RETRY_ATTEMPTS = 3 adapter = TypeAdapter( BeatmapsetModel.generate_typeddict( ( @@ -90,41 +80,6 @@ class BeatmapsetFetcher(BaseFetcher): return homepage_queries - async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict: - """覆盖基类方法,添加速率限制和429错误处理""" - # 在请求前获取速率限制许可 - if retry_times > MAX_RETRY_ATTEMPTS: - raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}") - - await osu_api_rate_limiter.acquire() - - # 检查 token 是否过期,如果过期则刷新 - if self.is_token_expired(): - await self.grant_access_token() - - header = kwargs.pop("headers", {}) - header.update(self.header) - - async with AsyncClient() as client: - response = await client.request( - method, - url, - headers=header, - **kwargs, - ) - - # 处理 429 错误 - 直接抛出异常,不重试 - if response.status_code == 429: - logger.warning(f"Rate limit exceeded (429) for {url}") - raise RateLimitError(f"Rate limit exceeded for {url}. 
Please try again later.") - if response.status_code == 401: - logger.warning(f"Received 401 error for {url}") - await self._clear_access_token() - return await self.request_api(url, method, retry_times=retry_times + 1, **kwargs) - - response.raise_for_status() - return response.json() - @staticmethod def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str: """生成搜索缓存键""" @@ -242,10 +197,7 @@ class BeatmapsetFetcher(BaseFetcher): # 不立即创建任务,而是延迟一段时间再预取 async def delayed_prefetch(): await asyncio.sleep(3.0) # 延迟3秒 - try: - await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1) - except RateLimitError: - logger.info("Prefetch skipped due to rate limit") + await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1) bg_tasks.add_task(delayed_prefetch) @@ -262,71 +214,65 @@ class BeatmapsetFetcher(BaseFetcher): if not current_cursor: return - try: - cursor = current_cursor.copy() + cursor = current_cursor.copy() - for page in range(1, pages + 1): - # 使用当前 cursor 请求下一页 - next_query = query.model_copy() + for page in range(1, pages + 1): + # 使用当前 cursor 请求下一页 + next_query = query.model_copy() - logger.debug(f"Prefetching page {page + 1}") + logger.debug(f"Prefetching page {page + 1}") - # 生成下一页的缓存键 - next_cache_key = self._generate_cache_key(next_query, cursor) + # 生成下一页的缓存键 + next_cache_key = self._generate_cache_key(next_query, cursor) - # 检查是否已经缓存 - if await redis_client.exists(next_cache_key): - logger.debug(f"Page {page + 1} already cached") - # 尝试从缓存获取cursor继续预取 - cached_data = await redis_client.get(next_cache_key) - if cached_data: - try: - data = json.loads(cached_data) - if data.get("cursor"): - cursor = data["cursor"] - continue - except Exception: - logger.warning("Failed to parse cached data for cursor") - break + # 检查是否已经缓存 + if await redis_client.exists(next_cache_key): + logger.debug(f"Page {page + 1} already cached") + # 尝试从缓存获取cursor继续预取 + cached_data = await redis_client.get(next_cache_key) + if cached_data: + try: + data = json.loads(cached_data) + if data.get("cursor"): + cursor = data["cursor"] + continue + except Exception: + logger.warning("Failed to parse cached data for cursor") + break - # 在预取页面之间添加延迟,避免突发请求 - if page > 1: - await asyncio.sleep(1.5) # 1.5秒延迟 + # 在预取页面之间添加延迟,避免突发请求 + if page > 1: + await asyncio.sleep(1.5) # 1.5秒延迟 - # 请求下一页数据 - params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) + # 请求下一页数据 + params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) - for k, v in cursor.items(): - params[f"cursor[{k}]"] = v + for k, v in cursor.items(): + params[f"cursor[{k}]"] = v - api_response = await self.request_api( - "https://osu.ppy.sh/api/v2/beatmapsets/search", - params=params, - ) + api_response = await self.request_api( + "https://osu.ppy.sh/api/v2/beatmapsets/search", + params=params, + ) - # 处理响应中的cursor信息 - if api_response.get("cursor"): - cursor_dict = api_response["cursor"] - api_response["cursor_string"] = self._encode_cursor(cursor_dict) - cursor = cursor_dict # 更新cursor用于下一页 - else: - # 没有更多页面了 - break + # 处理响应中的cursor信息 + if api_response.get("cursor"): + cursor_dict = api_response["cursor"] + api_response["cursor_string"] = self._encode_cursor(cursor_dict) + cursor = cursor_dict # 更新cursor用于下一页 + else: + # 没有更多页面了 + break - # 缓存结果(较短的TTL用于预取) - prefetch_ttl = 10 * 60 # 10 分钟 - await redis_client.set( - next_cache_key, - json.dumps(api_response, separators=(",", ":")), - ex=prefetch_ttl, - ) + # 
diff --git a/app/helpers/rate_limiter.py b/app/helpers/rate_limiter.py
deleted file mode 100644
index c80088a..0000000
--- a/app/helpers/rate_limiter.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""
-Rate limiter for osu! API requests to avoid abuse detection.
-Based on the osu! API v2 rate limits:
-- Default: at most 1200 requests per minute
-- Burst: up to 200 extra requests in a short window
-- Recommended: stay under 60 requests per minute to avoid abuse detection
-"""
-
-import asyncio
-from collections import deque
-import time
-
-from app.log import logger
-
-
-class RateLimiter:
-    """osu! API rate limiter"""
-
-    def __init__(
-        self,
-        max_requests_per_minute: int = 60,  # conservative limit
-        burst_limit: int = 10,  # burst limit within a short window
-        burst_window: float = 10.0,  # burst window in seconds
-    ):
-        self.max_requests_per_minute = max_requests_per_minute
-        self.burst_limit = burst_limit
-        self.burst_window = burst_window
-
-        # Track request timestamps
-        self.request_times: deque[float] = deque()
-        self.burst_times: deque[float] = deque()
-
-        # Lock to keep concurrent access safe
-        self._lock = asyncio.Lock()
-
-    async def acquire(self) -> None:
-        """Acquire a request permit, waiting if the limit has been reached"""
-        async with self._lock:
-            current_time = time.time()
-
-            # Drop expired request records
-            self._cleanup_old_requests(current_time)
-
-            # Check whether we need to wait
-            wait_time = self._calculate_wait_time(current_time)
-
-            if wait_time > 0:
-                logger.opt(colors=True).info(
-                    f"[RateLimiter] Rate limit reached, waiting {wait_time:.2f}s"
-                )
-                await asyncio.sleep(wait_time)
-                current_time = time.time()
-                self._cleanup_old_requests(current_time)
-
-            # Record the current request
-            self.request_times.append(current_time)
-            self.burst_times.append(current_time)
-
-            logger.opt(colors=True).debug(
-                f"[RateLimiter] Request granted. "
" - f"Recent requests: {len(self.request_times)}/min, " - f"{len(self.burst_times)}/{self.burst_window}s" - ) - - def _cleanup_old_requests(self, current_time: float) -> None: - """清理过期的请求记录""" - # 清理1分钟前的请求 - minute_ago = current_time - 60.0 - while self.request_times and self.request_times[0] < minute_ago: - self.request_times.popleft() - - # 清理突发窗口外的请求 - burst_window_ago = current_time - self.burst_window - while self.burst_times and self.burst_times[0] < burst_window_ago: - self.burst_times.popleft() - - def _calculate_wait_time(self, current_time: float) -> float: - """计算需要等待的时间""" - # 检查每分钟限制 - if len(self.request_times) >= self.max_requests_per_minute: - # 需要等到最老的请求超过1分钟 - oldest_request = self.request_times[0] - wait_for_minute_limit = oldest_request + 60.0 - current_time - else: - wait_for_minute_limit = 0.0 - - # 检查突发限制 - if len(self.burst_times) >= self.burst_limit: - # 需要等到最老的突发请求超过突发窗口 - oldest_burst = self.burst_times[0] - wait_for_burst_limit = oldest_burst + self.burst_window - current_time - else: - wait_for_burst_limit = 0.0 - - return max(wait_for_minute_limit, wait_for_burst_limit, 0.0) - - def get_status(self) -> dict[str, int | float]: - """获取当前速率限制状态""" - current_time = time.time() - self._cleanup_old_requests(current_time) - - return { - "requests_this_minute": len(self.request_times), - "max_requests_per_minute": self.max_requests_per_minute, - "burst_requests": len(self.burst_times), - "burst_limit": self.burst_limit, - "next_reset_in_seconds": (60.0 - (current_time - self.request_times[0]) if self.request_times else 0.0), - } - - -# 全局速率限制器实例 -osu_api_rate_limiter = RateLimiter( - max_requests_per_minute=50, # 保守设置,低于建议的60 - burst_limit=8, # 短时间内最多8个请求 - burst_window=10.0, # 10秒窗口 -)