refactor(fetcher): implement passive rate limiter for API requests
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.dependencies.database import get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.log import fetcher_logger
|
from app.log import fetcher_logger
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient, HTTPStatusError
|
||||||
|
|
||||||
|
|
||||||
class TokenAuthError(Exception):
|
class TokenAuthError(Exception):
|
||||||
@@ -13,10 +14,63 @@ class TokenAuthError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PassiveRateLimiter:
|
||||||
|
"""
|
||||||
|
被动速率限制器
|
||||||
|
当收到 429 响应时,读取 Retry-After 头并暂停所有请求
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._retry_after_time: float | None = None
|
||||||
|
self._waiting_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
async def wait_if_limited(self) -> None:
|
||||||
|
"""如果正在限流中,等待限流解除"""
|
||||||
|
async with self._lock:
|
||||||
|
if self._retry_after_time is not None:
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time < self._retry_after_time:
|
||||||
|
wait_seconds = self._retry_after_time - current_time
|
||||||
|
logger.warning(f"Rate limited, waiting {wait_seconds:.2f} seconds")
|
||||||
|
await asyncio.sleep(wait_seconds)
|
||||||
|
self._retry_after_time = None
|
||||||
|
|
||||||
|
async def handle_rate_limit(self, retry_after: str | int | None) -> None:
|
||||||
|
"""
|
||||||
|
处理 429 响应,设置限流时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retry_after: Retry-After 头的值,可以是秒数或 HTTP 日期
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if retry_after is None:
|
||||||
|
# 如果没有 Retry-After 头,默认等待 60 秒
|
||||||
|
wait_seconds = 60
|
||||||
|
elif isinstance(retry_after, int):
|
||||||
|
wait_seconds = retry_after
|
||||||
|
elif retry_after.isdigit():
|
||||||
|
wait_seconds = int(retry_after)
|
||||||
|
else:
|
||||||
|
# 尝试解析 HTTP 日期格式
|
||||||
|
try:
|
||||||
|
retry_time = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z")
|
||||||
|
wait_seconds = max(0, (retry_time - datetime.utcnow()).total_seconds())
|
||||||
|
except ValueError:
|
||||||
|
# 解析失败,默认等待 60 秒
|
||||||
|
wait_seconds = 60
|
||||||
|
|
||||||
|
self._retry_after_time = time.time() + wait_seconds
|
||||||
|
logger.warning(f"Rate limit triggered, will retry after {wait_seconds} seconds")
|
||||||
|
|
||||||
|
|
||||||
logger = fetcher_logger("Fetcher")
|
logger = fetcher_logger("Fetcher")
|
||||||
|
|
||||||
|
|
||||||
class BaseFetcher:
|
class BaseFetcher:
|
||||||
|
# 类级别的 rate limiter,所有实例共享
|
||||||
|
_rate_limiter = PassiveRateLimiter()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client_id: str,
|
client_id: str,
|
||||||
@@ -51,7 +105,7 @@ class BaseFetcher:
|
|||||||
|
|
||||||
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
|
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
|
||||||
"""
|
"""
|
||||||
发送 API 请求
|
发送 API 请求,支持被动速率限制
|
||||||
"""
|
"""
|
||||||
await self.ensure_valid_access_token()
|
await self.ensure_valid_access_token()
|
||||||
|
|
||||||
@@ -59,24 +113,41 @@ class BaseFetcher:
|
|||||||
attempt = 0
|
attempt = 0
|
||||||
|
|
||||||
while attempt < 2:
|
while attempt < 2:
|
||||||
|
# 在发送请求前等待速率限制
|
||||||
|
await self._rate_limiter.wait_if_limited()
|
||||||
|
|
||||||
request_headers = {**headers, **self.header}
|
request_headers = {**headers, **self.header}
|
||||||
request_kwargs = kwargs.copy()
|
request_kwargs = kwargs.copy()
|
||||||
|
|
||||||
async with AsyncClient() as client:
|
async with AsyncClient() as client:
|
||||||
response = await client.request(
|
try:
|
||||||
method,
|
response = await client.request(
|
||||||
url,
|
method,
|
||||||
headers=request_headers,
|
url,
|
||||||
**request_kwargs,
|
headers=request_headers,
|
||||||
)
|
**request_kwargs,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
if response.status_code != 401:
|
except HTTPStatusError as e:
|
||||||
response.raise_for_status()
|
# 处理 429 速率限制响应
|
||||||
return response.json()
|
if e.response.status_code == 429:
|
||||||
|
retry_after = e.response.headers.get("Retry-After")
|
||||||
|
logger.warning(f"Rate limited for {url}, Retry-After: {retry_after}")
|
||||||
|
await self._rate_limiter.handle_rate_limit(retry_after)
|
||||||
|
# 速率限制后重试当前请求(不增加 attempt)
|
||||||
|
continue
|
||||||
|
|
||||||
attempt += 1
|
# 处理 401 未授权响应
|
||||||
logger.warning(f"Received 401 error for {url}, attempt {attempt}")
|
if e.response.status_code == 401:
|
||||||
await self._handle_unauthorized()
|
attempt += 1
|
||||||
|
logger.warning(f"Received 401 error for {url}, attempt {attempt}")
|
||||||
|
await self._handle_unauthorized()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 其他 HTTP 错误直接抛出
|
||||||
|
raise
|
||||||
|
|
||||||
await self._clear_access_token()
|
await self._clear_access_token()
|
||||||
logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
|
logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
|
from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
|
||||||
from app.helpers.rate_limiter import osu_api_rate_limiter
|
|
||||||
from app.log import fetcher_logger
|
from app.log import fetcher_logger
|
||||||
from app.models.beatmap import SearchQueryModel
|
from app.models.beatmap import SearchQueryModel
|
||||||
from app.models.model import Cursor
|
from app.models.model import Cursor
|
||||||
@@ -12,21 +11,12 @@ from app.utils import bg_tasks
|
|||||||
|
|
||||||
from ._base import BaseFetcher
|
from ._base import BaseFetcher
|
||||||
|
|
||||||
from httpx import AsyncClient
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
|
|
||||||
|
|
||||||
class RateLimitError(Exception):
|
|
||||||
"""速率限制异常"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
logger = fetcher_logger("BeatmapsetFetcher")
|
logger = fetcher_logger("BeatmapsetFetcher")
|
||||||
|
|
||||||
|
|
||||||
MAX_RETRY_ATTEMPTS = 3
|
|
||||||
adapter = TypeAdapter(
|
adapter = TypeAdapter(
|
||||||
BeatmapsetModel.generate_typeddict(
|
BeatmapsetModel.generate_typeddict(
|
||||||
(
|
(
|
||||||
@@ -90,41 +80,6 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
|
|
||||||
return homepage_queries
|
return homepage_queries
|
||||||
|
|
||||||
async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict:
    """Override of the base method adding rate limiting and 429 handling.

    Args:
        url: Full API endpoint URL.
        method: HTTP method, defaults to GET.
        retry_times: Internal 401-retry counter; callers should not set it.
        **kwargs: Extra arguments forwarded to the HTTP client (params, headers, ...).

    Returns:
        Parsed JSON response body.

    Raises:
        RuntimeError: When the 401 retry budget is exhausted.
        RateLimitError: When the API answers 429.
    """
    if retry_times > MAX_RETRY_ATTEMPTS:
        raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}")

    # Acquire a rate-limit permit before sending the request.
    await osu_api_rate_limiter.acquire()

    # Refresh the token if it has expired.
    if self.is_token_expired():
        await self.grant_access_token()

    # Merge headers WITHOUT mutating the caller's dict: the original code
    # popped "headers" and .update()d it in place, which both mutated the
    # caller's object and dropped the custom headers on the 401 retry below.
    caller_headers = kwargs.pop("headers", {})
    header = {**caller_headers, **self.header}

    async with AsyncClient() as client:
        response = await client.request(
            method,
            url,
            headers=header,
            **kwargs,
        )

        # 429: raise immediately, no retry.
        if response.status_code == 429:
            logger.warning(f"Rate limit exceeded (429) for {url}")
            raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.")
        if response.status_code == 401:
            logger.warning(f"Received 401 error for {url}")
            await self._clear_access_token()
            # Re-pass the caller's original headers so they survive the retry.
            return await self.request_api(url, method, retry_times=retry_times + 1, headers=caller_headers, **kwargs)

        response.raise_for_status()
        return response.json()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
|
def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
|
||||||
"""生成搜索缓存键"""
|
"""生成搜索缓存键"""
|
||||||
@@ -242,10 +197,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
# 不立即创建任务,而是延迟一段时间再预取
|
# 不立即创建任务,而是延迟一段时间再预取
|
||||||
async def delayed_prefetch():
|
async def delayed_prefetch():
|
||||||
await asyncio.sleep(3.0) # 延迟3秒
|
await asyncio.sleep(3.0) # 延迟3秒
|
||||||
try:
|
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
|
||||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
|
|
||||||
except RateLimitError:
|
|
||||||
logger.info("Prefetch skipped due to rate limit")
|
|
||||||
|
|
||||||
bg_tasks.add_task(delayed_prefetch)
|
bg_tasks.add_task(delayed_prefetch)
|
||||||
|
|
||||||
@@ -262,71 +214,65 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
if not current_cursor:
|
if not current_cursor:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
cursor = current_cursor.copy()
|
||||||
cursor = current_cursor.copy()
|
|
||||||
|
|
||||||
for page in range(1, pages + 1):
|
for page in range(1, pages + 1):
|
||||||
# 使用当前 cursor 请求下一页
|
# 使用当前 cursor 请求下一页
|
||||||
next_query = query.model_copy()
|
next_query = query.model_copy()
|
||||||
|
|
||||||
logger.debug(f"Prefetching page {page + 1}")
|
logger.debug(f"Prefetching page {page + 1}")
|
||||||
|
|
||||||
# 生成下一页的缓存键
|
# 生成下一页的缓存键
|
||||||
next_cache_key = self._generate_cache_key(next_query, cursor)
|
next_cache_key = self._generate_cache_key(next_query, cursor)
|
||||||
|
|
||||||
# 检查是否已经缓存
|
# 检查是否已经缓存
|
||||||
if await redis_client.exists(next_cache_key):
|
if await redis_client.exists(next_cache_key):
|
||||||
logger.debug(f"Page {page + 1} already cached")
|
logger.debug(f"Page {page + 1} already cached")
|
||||||
# 尝试从缓存获取cursor继续预取
|
# 尝试从缓存获取cursor继续预取
|
||||||
cached_data = await redis_client.get(next_cache_key)
|
cached_data = await redis_client.get(next_cache_key)
|
||||||
if cached_data:
|
if cached_data:
|
||||||
try:
|
try:
|
||||||
data = json.loads(cached_data)
|
data = json.loads(cached_data)
|
||||||
if data.get("cursor"):
|
if data.get("cursor"):
|
||||||
cursor = data["cursor"]
|
cursor = data["cursor"]
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to parse cached data for cursor")
|
logger.warning("Failed to parse cached data for cursor")
|
||||||
break
|
break
|
||||||
|
|
||||||
# 在预取页面之间添加延迟,避免突发请求
|
# 在预取页面之间添加延迟,避免突发请求
|
||||||
if page > 1:
|
if page > 1:
|
||||||
await asyncio.sleep(1.5) # 1.5秒延迟
|
await asyncio.sleep(1.5) # 1.5秒延迟
|
||||||
|
|
||||||
# 请求下一页数据
|
# 请求下一页数据
|
||||||
params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
|
params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
|
||||||
|
|
||||||
for k, v in cursor.items():
|
for k, v in cursor.items():
|
||||||
params[f"cursor[{k}]"] = v
|
params[f"cursor[{k}]"] = v
|
||||||
|
|
||||||
api_response = await self.request_api(
|
api_response = await self.request_api(
|
||||||
"https://osu.ppy.sh/api/v2/beatmapsets/search",
|
"https://osu.ppy.sh/api/v2/beatmapsets/search",
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 处理响应中的cursor信息
|
# 处理响应中的cursor信息
|
||||||
if api_response.get("cursor"):
|
if api_response.get("cursor"):
|
||||||
cursor_dict = api_response["cursor"]
|
cursor_dict = api_response["cursor"]
|
||||||
api_response["cursor_string"] = self._encode_cursor(cursor_dict)
|
api_response["cursor_string"] = self._encode_cursor(cursor_dict)
|
||||||
cursor = cursor_dict # 更新cursor用于下一页
|
cursor = cursor_dict # 更新cursor用于下一页
|
||||||
else:
|
else:
|
||||||
# 没有更多页面了
|
# 没有更多页面了
|
||||||
break
|
break
|
||||||
|
|
||||||
# 缓存结果(较短的TTL用于预取)
|
# 缓存结果(较短的TTL用于预取)
|
||||||
prefetch_ttl = 10 * 60 # 10 分钟
|
prefetch_ttl = 10 * 60 # 10 分钟
|
||||||
await redis_client.set(
|
await redis_client.set(
|
||||||
next_cache_key,
|
next_cache_key,
|
||||||
json.dumps(api_response, separators=(",", ":")),
|
json.dumps(api_response, separators=(",", ":")),
|
||||||
ex=prefetch_ttl,
|
ex=prefetch_ttl,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
|
logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
|
||||||
|
|
||||||
except RateLimitError:
|
|
||||||
logger.info("Prefetch stopped due to rate limit")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Prefetch failed: {e}")
|
|
||||||
|
|
||||||
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
|
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
|
||||||
"""预热主页缓存"""
|
"""预热主页缓存"""
|
||||||
@@ -370,12 +316,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
|
logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
|
||||||
|
|
||||||
if api_response.get("cursor"):
|
if api_response.get("cursor"):
|
||||||
try:
|
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
|
||||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
|
|
||||||
except RateLimitError:
|
|
||||||
logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
|
|
||||||
|
|
||||||
except RateLimitError:
|
|
||||||
logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to warmup cache for {query.sort}: {e}")
|
logger.error(f"Failed to warmup cache for {query.sort}: {e}")
|
||||||
|
|||||||
@@ -1,116 +0,0 @@
|
|||||||
"""
|
|
||||||
Rate limiter for osu! API requests to avoid abuse detection.
|
|
||||||
根据 osu! API v2 的速率限制设计:
|
|
||||||
- 默认:每分钟最多 1200 次请求
|
|
||||||
- 突发:短时间内最多 200 次额外请求
|
|
||||||
- 建议:每分钟不超过 60 次请求以避免滥用检测
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections import deque
|
|
||||||
import time
|
|
||||||
|
|
||||||
from app.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
|
||||||
"""osu! API 速率限制器"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_requests_per_minute: int = 60, # 保守的限制
|
|
||||||
burst_limit: int = 10, # 短时间内的突发限制
|
|
||||||
burst_window: float = 10.0, # 突发窗口(秒)
|
|
||||||
):
|
|
||||||
self.max_requests_per_minute = max_requests_per_minute
|
|
||||||
self.burst_limit = burst_limit
|
|
||||||
self.burst_window = burst_window
|
|
||||||
|
|
||||||
# 跟踪请求时间戳
|
|
||||||
self.request_times: deque[float] = deque()
|
|
||||||
self.burst_times: deque[float] = deque()
|
|
||||||
|
|
||||||
# 锁确保线程安全
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def acquire(self) -> None:
|
|
||||||
"""获取请求许可,如果超过限制则等待"""
|
|
||||||
async with self._lock:
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# 清理过期的请求记录
|
|
||||||
self._cleanup_old_requests(current_time)
|
|
||||||
|
|
||||||
# 检查是否需要等待
|
|
||||||
wait_time = self._calculate_wait_time(current_time)
|
|
||||||
|
|
||||||
if wait_time > 0:
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"<yellow>[RateLimiter]</yellow> Rate limit reached, waiting {wait_time:.2f}s"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
current_time = time.time()
|
|
||||||
self._cleanup_old_requests(current_time)
|
|
||||||
|
|
||||||
# 记录当前请求
|
|
||||||
self.request_times.append(current_time)
|
|
||||||
self.burst_times.append(current_time)
|
|
||||||
|
|
||||||
logger.opt(colors=True).debug(
|
|
||||||
f"<green>[RateLimiter]</green> Request granted. "
|
|
||||||
f"Recent requests: {len(self.request_times)}/min, "
|
|
||||||
f"{len(self.burst_times)}/{self.burst_window}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _cleanup_old_requests(self, current_time: float) -> None:
|
|
||||||
"""清理过期的请求记录"""
|
|
||||||
# 清理1分钟前的请求
|
|
||||||
minute_ago = current_time - 60.0
|
|
||||||
while self.request_times and self.request_times[0] < minute_ago:
|
|
||||||
self.request_times.popleft()
|
|
||||||
|
|
||||||
# 清理突发窗口外的请求
|
|
||||||
burst_window_ago = current_time - self.burst_window
|
|
||||||
while self.burst_times and self.burst_times[0] < burst_window_ago:
|
|
||||||
self.burst_times.popleft()
|
|
||||||
|
|
||||||
def _calculate_wait_time(self, current_time: float) -> float:
|
|
||||||
"""计算需要等待的时间"""
|
|
||||||
# 检查每分钟限制
|
|
||||||
if len(self.request_times) >= self.max_requests_per_minute:
|
|
||||||
# 需要等到最老的请求超过1分钟
|
|
||||||
oldest_request = self.request_times[0]
|
|
||||||
wait_for_minute_limit = oldest_request + 60.0 - current_time
|
|
||||||
else:
|
|
||||||
wait_for_minute_limit = 0.0
|
|
||||||
|
|
||||||
# 检查突发限制
|
|
||||||
if len(self.burst_times) >= self.burst_limit:
|
|
||||||
# 需要等到最老的突发请求超过突发窗口
|
|
||||||
oldest_burst = self.burst_times[0]
|
|
||||||
wait_for_burst_limit = oldest_burst + self.burst_window - current_time
|
|
||||||
else:
|
|
||||||
wait_for_burst_limit = 0.0
|
|
||||||
|
|
||||||
return max(wait_for_minute_limit, wait_for_burst_limit, 0.0)
|
|
||||||
|
|
||||||
def get_status(self) -> dict[str, int | float]:
|
|
||||||
"""获取当前速率限制状态"""
|
|
||||||
current_time = time.time()
|
|
||||||
self._cleanup_old_requests(current_time)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"requests_this_minute": len(self.request_times),
|
|
||||||
"max_requests_per_minute": self.max_requests_per_minute,
|
|
||||||
"burst_requests": len(self.burst_times),
|
|
||||||
"burst_limit": self.burst_limit,
|
|
||||||
"next_reset_in_seconds": (60.0 - (current_time - self.request_times[0]) if self.request_times else 0.0),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局速率限制器实例
|
|
||||||
osu_api_rate_limiter = RateLimiter(
|
|
||||||
max_requests_per_minute=50, # 保守设置,低于建议的60
|
|
||||||
burst_limit=8, # 短时间内最多8个请求
|
|
||||||
burst_window=10.0, # 10秒窗口
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user