refactor(fetcher): implement passive rate limiter for API requests

MingxuanGame
2025-11-23 13:42:21 +00:00
parent 40da994ae8
commit 819f52450c
3 changed files with 135 additions and 239 deletions

View File

@@ -1,10 +1,11 @@
 import asyncio
+from datetime import datetime
+import time
 from app.dependencies.database import get_redis
 from app.log import fetcher_logger
-from httpx import AsyncClient
+from httpx import AsyncClient, HTTPStatusError
 class TokenAuthError(Exception):
@@ -13,10 +14,63 @@ class TokenAuthError(Exception):
     pass


+class PassiveRateLimiter:
+    """
+    Passive rate limiter.
+
+    When a 429 response is received, read the Retry-After header and pause all requests.
+    """
+
+    def __init__(self):
+        self._lock = asyncio.Lock()
+        self._retry_after_time: float | None = None
+        self._waiting_tasks: set[asyncio.Task] = set()
+
+    async def wait_if_limited(self) -> None:
+        """If a rate limit is active, wait until it is lifted."""
+        async with self._lock:
+            if self._retry_after_time is not None:
+                current_time = time.time()
+                if current_time < self._retry_after_time:
+                    wait_seconds = self._retry_after_time - current_time
+                    logger.warning(f"Rate limited, waiting {wait_seconds:.2f} seconds")
+                    await asyncio.sleep(wait_seconds)
+                self._retry_after_time = None
+
+    async def handle_rate_limit(self, retry_after: str | int | None) -> None:
+        """
+        Handle a 429 response by setting the rate-limit deadline.
+
+        Args:
+            retry_after: value of the Retry-After header; either a number of seconds or an HTTP date
+        """
+        async with self._lock:
+            if retry_after is None:
+                # no Retry-After header; default to waiting 60 seconds
+                wait_seconds = 60
+            elif isinstance(retry_after, int):
+                wait_seconds = retry_after
+            elif retry_after.isdigit():
+                wait_seconds = int(retry_after)
+            else:
+                # try to parse the HTTP date format
+                try:
+                    retry_time = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z")
+                    wait_seconds = max(0, (retry_time - datetime.utcnow()).total_seconds())
+                except ValueError:
+                    # parsing failed; default to waiting 60 seconds
+                    wait_seconds = 60
+            self._retry_after_time = time.time() + wait_seconds
+            logger.warning(f"Rate limit triggered, will retry after {wait_seconds} seconds")
+
+
 logger = fetcher_logger("Fetcher")


 class BaseFetcher:
+    # class-level rate limiter, shared by all instances
+    _rate_limiter = PassiveRateLimiter()
+
     def __init__(
         self,
         client_id: str,
@@ -51,7 +105,7 @@ class BaseFetcher:
     async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
         """
-        Send an API request
+        Send an API request, with passive rate limiting support
         """
         await self.ensure_valid_access_token()
@@ -59,24 +113,41 @@ class BaseFetcher:
         attempt = 0
         while attempt < 2:
+            # wait out any active rate limit before sending the request
+            await self._rate_limiter.wait_if_limited()
+
             request_headers = {**headers, **self.header}
             request_kwargs = kwargs.copy()
             async with AsyncClient() as client:
-                response = await client.request(
-                    method,
-                    url,
-                    headers=request_headers,
-                    **request_kwargs,
-                )
+                try:
+                    response = await client.request(
+                        method,
+                        url,
+                        headers=request_headers,
+                        **request_kwargs,
+                    )
+                    response.raise_for_status()
+                    return response.json()
-                if response.status_code != 401:
-                    response.raise_for_status()
-                    return response.json()
+                except HTTPStatusError as e:
+                    # handle 429 rate-limit responses
+                    if e.response.status_code == 429:
+                        retry_after = e.response.headers.get("Retry-After")
+                        logger.warning(f"Rate limited for {url}, Retry-After: {retry_after}")
+                        await self._rate_limiter.handle_rate_limit(retry_after)
+                        # retry the current request after the limit lifts (does not increment attempt)
+                        continue
-            attempt += 1
-            logger.warning(f"Received 401 error for {url}, attempt {attempt}")
-            await self._handle_unauthorized()
+                    # handle 401 unauthorized responses
+                    if e.response.status_code == 401:
+                        attempt += 1
+                        logger.warning(f"Received 401 error for {url}, attempt {attempt}")
+                        await self._handle_unauthorized()
+                        continue
+
+                    # re-raise any other HTTP error
+                    raise

         await self._clear_access_token()
         logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")

View File

@@ -4,7 +4,6 @@ import hashlib
 import json
 from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
-from app.helpers.rate_limiter import osu_api_rate_limiter
 from app.log import fetcher_logger
 from app.models.beatmap import SearchQueryModel
 from app.models.model import Cursor
@@ -12,21 +11,12 @@ from app.utils import bg_tasks
 from ._base import BaseFetcher
 from httpx import AsyncClient
 from pydantic import TypeAdapter
 import redis.asyncio as redis

-class RateLimitError(Exception):
-    """Rate limit exception"""
-
-    pass

 logger = fetcher_logger("BeatmapsetFetcher")

-MAX_RETRY_ATTEMPTS = 3

 adapter = TypeAdapter(
     BeatmapsetModel.generate_typeddict(
         (
@@ -90,41 +80,6 @@ class BeatmapsetFetcher(BaseFetcher):
         return homepage_queries

-    async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict:
-        """Override of the base-class method: adds rate limiting and 429 error handling"""
-        # acquire a rate-limit permit before the request
-        if retry_times > MAX_RETRY_ATTEMPTS:
-            raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}")
-        await osu_api_rate_limiter.acquire()
-
-        # check whether the token has expired and refresh it if so
-        if self.is_token_expired():
-            await self.grant_access_token()
-
-        header = kwargs.pop("headers", {})
-        header.update(self.header)
-        async with AsyncClient() as client:
-            response = await client.request(
-                method,
-                url,
-                headers=header,
-                **kwargs,
-            )
-            # handle 429 errors: raise immediately, no retry
-            if response.status_code == 429:
-                logger.warning(f"Rate limit exceeded (429) for {url}")
-                raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.")
-            if response.status_code == 401:
-                logger.warning(f"Received 401 error for {url}")
-                await self._clear_access_token()
-                return await self.request_api(url, method, retry_times=retry_times + 1, **kwargs)
-            response.raise_for_status()
-            return response.json()

     @staticmethod
     def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
         """Generate the search cache key"""
@@ -242,10 +197,7 @@ class BeatmapsetFetcher(BaseFetcher):
         # don't create the task immediately; delay a while before prefetching
         async def delayed_prefetch():
             await asyncio.sleep(3.0)  # 3-second delay
-            try:
-                await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
-            except RateLimitError:
-                logger.info("Prefetch skipped due to rate limit")
+            await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)

         bg_tasks.add_task(delayed_prefetch)
@@ -262,71 +214,65 @@ class BeatmapsetFetcher(BaseFetcher):
         if not current_cursor:
             return

-        try:
-            cursor = current_cursor.copy()
+        cursor = current_cursor.copy()

-            for page in range(1, pages + 1):
-                # use the current cursor to request the next page
-                next_query = query.model_copy()
+        for page in range(1, pages + 1):
+            # use the current cursor to request the next page
+            next_query = query.model_copy()

-                logger.debug(f"Prefetching page {page + 1}")
+            logger.debug(f"Prefetching page {page + 1}")

-                # generate the cache key for the next page
-                next_cache_key = self._generate_cache_key(next_query, cursor)
+            # generate the cache key for the next page
+            next_cache_key = self._generate_cache_key(next_query, cursor)

-                # check whether it is already cached
-                if await redis_client.exists(next_cache_key):
-                    logger.debug(f"Page {page + 1} already cached")
-                    # try to read the cursor from the cache so prefetching can continue
-                    cached_data = await redis_client.get(next_cache_key)
-                    if cached_data:
-                        try:
-                            data = json.loads(cached_data)
-                            if data.get("cursor"):
-                                cursor = data["cursor"]
-                                continue
-                        except Exception:
-                            logger.warning("Failed to parse cached data for cursor")
-                    break
+            # check whether it is already cached
+            if await redis_client.exists(next_cache_key):
+                logger.debug(f"Page {page + 1} already cached")
+                # try to read the cursor from the cache so prefetching can continue
+                cached_data = await redis_client.get(next_cache_key)
+                if cached_data:
+                    try:
+                        data = json.loads(cached_data)
+                        if data.get("cursor"):
+                            cursor = data["cursor"]
+                            continue
+                    except Exception:
+                        logger.warning("Failed to parse cached data for cursor")
+                break

-                # add a delay between prefetched pages to avoid burst requests
-                if page > 1:
-                    await asyncio.sleep(1.5)  # 1.5-second delay
+            # add a delay between prefetched pages to avoid burst requests
+            if page > 1:
+                await asyncio.sleep(1.5)  # 1.5-second delay

-                # request the next page of data
-                params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
+            # request the next page of data
+            params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)

-                for k, v in cursor.items():
-                    params[f"cursor[{k}]"] = v
+            for k, v in cursor.items():
+                params[f"cursor[{k}]"] = v

-                api_response = await self.request_api(
-                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
-                    params=params,
-                )
+            api_response = await self.request_api(
+                "https://osu.ppy.sh/api/v2/beatmapsets/search",
+                params=params,
+            )

-                # handle the cursor information in the response
-                if api_response.get("cursor"):
-                    cursor_dict = api_response["cursor"]
-                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
-                    cursor = cursor_dict  # update the cursor for the next page
-                else:
-                    # no more pages
-                    break
+            # handle the cursor information in the response
+            if api_response.get("cursor"):
+                cursor_dict = api_response["cursor"]
+                api_response["cursor_string"] = self._encode_cursor(cursor_dict)
+                cursor = cursor_dict  # update the cursor for the next page
+            else:
+                # no more pages
+                break

-                # cache the result with a shorter TTL for prefetched pages
-                prefetch_ttl = 10 * 60  # 10 minutes
-                await redis_client.set(
-                    next_cache_key,
-                    json.dumps(api_response, separators=(",", ":")),
-                    ex=prefetch_ttl,
-                )
+            # cache the result with a shorter TTL for prefetched pages
+            prefetch_ttl = 10 * 60  # 10 minutes
+            await redis_client.set(
+                next_cache_key,
+                json.dumps(api_response, separators=(",", ":")),
+                ex=prefetch_ttl,
+            )

-                logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
-        except RateLimitError:
-            logger.info("Prefetch stopped due to rate limit")
-        except Exception as e:
-            logger.warning(f"Prefetch failed: {e}")
+            logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")

     async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
         """Warm up the homepage cache"""
@@ -370,12 +316,7 @@ class BeatmapsetFetcher(BaseFetcher):
logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
if api_response.get("cursor"):
try:
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except RateLimitError:
logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except RateLimitError:
logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
except Exception as e:
logger.error(f"Failed to warmup cache for {query.sort}: {e}")