refactor(fetcher): implement passive rate limiter for API requests
@@ -4,7 +4,6 @@ import hashlib
import json

from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import fetcher_logger
from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
@@ -12,21 +11,12 @@ from app.utils import bg_tasks

from ._base import BaseFetcher

from httpx import AsyncClient
from pydantic import TypeAdapter
import redis.asyncio as redis


class RateLimitError(Exception):
    """Rate limit exception."""

    pass


logger = fetcher_logger("BeatmapsetFetcher")


MAX_RETRY_ATTEMPTS = 3
adapter = TypeAdapter(
    BeatmapsetModel.generate_typeddict(
        (
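
Note: `osu_api_rate_limiter` is imported above, but `app/helpers/rate_limiter.py` is not part of this diff. As a rough illustration of the passive approach named in the commit title, a minimal asyncio limiter exposing the awaited `acquire()` might look like the sketch below; the class name and the 1-second interval are assumptions, not the committed code.

```python
# Hypothetical sketch of a passive rate limiter with an awaitable acquire();
# the real app.helpers.rate_limiter module is not shown in this diff.
import asyncio
import time


class PassiveRateLimiter:
    """Spaces requests out by a minimum interval instead of queueing bursts."""

    def __init__(self, min_interval: float = 1.0) -> None:
        self._min_interval = min_interval
        self._last_request = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        # Serialize callers so each one waits out the remaining interval.
        async with self._lock:
            now = time.monotonic()
            wait = self._min_interval - (now - self._last_request)
            if wait > 0:
                await asyncio.sleep(wait)
            self._last_request = time.monotonic()


osu_api_rate_limiter = PassiveRateLimiter(min_interval=1.0)
```
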
@@ -90,41 +80,6 @@ class BeatmapsetFetcher(BaseFetcher):

        return homepage_queries

    async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict:
        """Override the base-class method to add rate limiting and 429 error handling."""
        # Acquire a rate-limit permit before sending the request
        if retry_times > MAX_RETRY_ATTEMPTS:
            raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}")

        await osu_api_rate_limiter.acquire()

        # Check whether the token has expired; refresh it if so
        if self.is_token_expired():
            await self.grant_access_token()

        header = kwargs.pop("headers", {})
        header.update(self.header)

        async with AsyncClient() as client:
            response = await client.request(
                method,
                url,
                headers=header,
                **kwargs,
            )

        # Handle 429 errors: raise immediately, do not retry
        if response.status_code == 429:
            logger.warning(f"Rate limit exceeded (429) for {url}")
            raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.")
        if response.status_code == 401:
            logger.warning(f"Received 401 error for {url}")
            await self._clear_access_token()
            return await self.request_api(url, method, retry_times=retry_times + 1, **kwargs)

        response.raise_for_status()
        return response.json()

    @staticmethod
    def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
        """Generate the search cache key."""
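
The hunk cuts off before the body of `_generate_cache_key`. Given the `import hashlib` context at the top of the file, one plausible shape is hashing the canonical query parameters together with the cursor; this sketch is an assumption, not the committed implementation.

```python
# Hypothetical cache-key derivation; the actual body is truncated by the
# hunk above. Assumes params and cursor are plain JSON-serializable dicts.
import hashlib
import json


def generate_cache_key(params: dict, cursor: dict) -> str:
    """Derive a stable Redis key from the search params and cursor."""
    payload = json.dumps(
        {"query": params, "cursor": cursor},
        sort_keys=True,  # canonical ordering: equal searches map to one key
        separators=(",", ":"),
    )
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    return f"beatmapset:search:{digest}"


# Example: logically identical queries always produce the same key.
key = generate_cache_key({"sort": "ranked_desc", "q": ""}, {"page": 2})
```
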
@@ -242,10 +197,7 @@ class BeatmapsetFetcher(BaseFetcher):
        # Don't create the task immediately; delay the prefetch for a few seconds
        async def delayed_prefetch():
            await asyncio.sleep(3.0)  # 3-second delay
            try:
                await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
            except RateLimitError:
                logger.info("Prefetch skipped due to rate limit")

        bg_tasks.add_task(delayed_prefetch)

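`bg_tasks.add_task` comes from `app.utils`, which this diff does not show. A common shape for such a helper keeps strong references to fire-and-forget tasks so they are not garbage-collected mid-flight; a hypothetical sketch:

```python
# Hypothetical sketch of app.utils.bg_tasks (not shown in this diff).
import asyncio
from collections.abc import Callable, Coroutine

_tasks: set[asyncio.Task] = set()


def add_task(factory: Callable[[], Coroutine]) -> None:
    # The caller passes a coroutine function; schedule it on the running loop.
    task = asyncio.get_running_loop().create_task(factory())
    _tasks.add(task)  # keep a strong reference until the task finishes
    task.add_done_callback(_tasks.discard)
```
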
@@ -262,71 +214,65 @@ class BeatmapsetFetcher(BaseFetcher):
        if not current_cursor:
            return

        try:
            cursor = current_cursor.copy()

            for page in range(1, pages + 1):
                # Request the next page using the current cursor
                next_query = query.model_copy()

                logger.debug(f"Prefetching page {page + 1}")

                # Generate the cache key for the next page
                next_cache_key = self._generate_cache_key(next_query, cursor)

                # Check whether the page is already cached
                if await redis_client.exists(next_cache_key):
                    logger.debug(f"Page {page + 1} already cached")
                    # Try to get the cursor from the cache to keep prefetching
                    cached_data = await redis_client.get(next_cache_key)
                    if cached_data:
                        try:
                            data = json.loads(cached_data)
                            if data.get("cursor"):
                                cursor = data["cursor"]
                                continue
                        except Exception:
                            logger.warning("Failed to parse cached data for cursor")
                    break

                # Add a delay between prefetched pages to avoid request bursts
                if page > 1:
                    await asyncio.sleep(1.5)  # 1.5-second delay

                # Request the next page of data
                params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)

                for k, v in cursor.items():
                    params[f"cursor[{k}]"] = v

                api_response = await self.request_api(
                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
                    params=params,
                )

                # Handle the cursor info in the response
                if api_response.get("cursor"):
                    cursor_dict = api_response["cursor"]
                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
                    cursor = cursor_dict  # update the cursor for the next page
                else:
                    # No more pages
                    break

                # Cache the result (shorter TTL for prefetched pages)
                prefetch_ttl = 10 * 60  # 10 minutes
                await redis_client.set(
                    next_cache_key,
                    json.dumps(api_response, separators=(",", ":")),
                    ex=prefetch_ttl,
                )

                logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")

        except RateLimitError:
            logger.info("Prefetch stopped due to rate limit")
        except Exception as e:
            logger.warning(f"Prefetch failed: {e}")

    async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
        """Warm up the homepage cache."""
@@ -370,12 +316,7 @@ class BeatmapsetFetcher(BaseFetcher):
                logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")

                if api_response.get("cursor"):
                    try:
                        await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
                    except RateLimitError:
                        logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")

            except RateLimitError:
                logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
            except Exception as e:
                logger.error(f"Failed to warmup cache for {query.sort}: {e}")
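
With this change, callers can treat `RateLimitError` as a soft failure rather than retrying. A minimal usage sketch (the fetcher instance and search parameters are assumed, not part of this diff):

```python
# Minimal caller sketch, assuming a constructed BeatmapsetFetcher instance.
async def search_once(fetcher, params: dict) -> dict | None:
    try:
        return await fetcher.request_api(
            "https://osu.ppy.sh/api/v2/beatmapsets/search",
            params=params,
        )
    except RateLimitError:
        # Passive strategy: skip this round and let the next call try again.
        return None
```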