g0v0-server/app/fetcher/beatmapset.py
咕谷酱 b4fd4e0256 Handle rate limit errors in BeatmapsetFetcher
Introduces RateLimitError to manage 429 responses from the API, updating request_api to raise this error and adding handling in prefetch and warmup logic to skip or log when rate limits are hit. Also improves error handling for authentication failures and token expiration.
2025-08-26 13:18:11 +08:00

from __future__ import annotations

import asyncio
import base64
import hashlib
import json

import redis.asyncio as redis
from httpx import AsyncClient

from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import logger
from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
from app.utils import bg_tasks

# TokenAuthError was raised below without being imported; it is assumed to be
# defined alongside BaseFetcher in ._base.
from ._base import BaseFetcher, TokenAuthError


class RateLimitError(Exception):
    """Raised when the osu! API responds with HTTP 429 (rate limit exceeded)."""


class BeatmapsetFetcher(BaseFetcher):
    @staticmethod
    def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
        """Build the list of homepage queries to pre-cache."""
        # Query combinations commonly requested from the homepage.
        homepage_queries = []
        # Primary sort orders.
        sorts = ["ranked_desc", "updated_desc", "favourites_desc", "plays_desc"]
        for sort in sorts:
            # First page: use the minimal parameter set so it matches user requests.
            query = SearchQueryModel(
                q="",
                s="leaderboard",
                sort=sort,  # type: ignore
            )
            homepage_queries.append((query, {}))
        return homepage_queries

    async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
        """Override the base method to add rate limiting and 429 handling."""
        # Acquire a rate-limit permit before making the request.
        await osu_api_rate_limiter.acquire()
        # Refresh the token if it has expired.
        if self.is_token_expired():
            await self.refresh_access_token()
        header = kwargs.pop("headers", {})
        header.update(self.header)
        async with AsyncClient() as client:
            response = await client.request(
                method,
                url,
                headers=header,
                **kwargs,
            )
            # Handle 429: raise immediately, do not retry.
            if response.status_code == 429:
                logger.warning(f"Rate limit exceeded (429) for {url}")
                raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.")
            # Handle 401: clear cached tokens and ask for re-authorization.
            if response.status_code == 401:
                logger.warning(f"Received 401 error for {url}")
                await self._clear_tokens()
                raise TokenAuthError(
                    f"Authentication failed. Please re-authorize using: {self.authorize_url}"
                )
            response.raise_for_status()
            return response.json()
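
    # Illustrative caller pattern (a sketch only; the fetcher instance, URL, and
    # backoff value below are assumptions, not part of this module):
    #
    #     try:
    #         data = await fetcher.request_api("https://osu.ppy.sh/api/v2/beatmapsets/1")
    #     except RateLimitError:
    #         await asyncio.sleep(5.0)  # back off instead of retrying immediately
    #     except TokenAuthError as e:
    #         logger.error(e)  # re-authorization required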

    @staticmethod
    def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
        """Generate a cache key for a search."""
        # Only include core query parameters; skip defaults.
        cache_data = {}
        # Add non-default / non-empty query parameters.
        if query.q:
            cache_data["q"] = query.q
        if query.s != "leaderboard":  # only include non-default values
            cache_data["s"] = query.s
        if hasattr(query, "sort") and query.sort:
            cache_data["sort"] = query.sort
        if query.nsfw is not False:  # only include non-default values
            cache_data["nsfw"] = query.nsfw
        if query.m is not None:
            cache_data["m"] = query.m
        if query.c:
            cache_data["c"] = query.c
        if query.l != "any":  # check against the language default
            cache_data["l"] = query.l
        if query.e:
            cache_data["e"] = query.e
        if query.r:
            cache_data["r"] = query.r
        if query.played is not False:
            cache_data["played"] = query.played
        # Add the cursor.
        if cursor:
            cache_data["cursor"] = cursor
        # Serialize to JSON and hash with MD5.
        cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
        cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
        logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
        return f"beatmapset:search:{cache_hash}"

    @staticmethod
    def _encode_cursor(cursor_dict: dict[str, int | float]) -> str:
        """Encode a cursor dict as a base64 string."""
        cursor_json = json.dumps(cursor_dict, separators=(",", ":"))
        return base64.b64encode(cursor_json.encode()).decode()

    @staticmethod
    def _decode_cursor(cursor_string: str) -> dict[str, int | float]:
        """Decode a base64 string back into a cursor dict."""
        try:
            cursor_json = base64.b64decode(cursor_string).decode()
            return json.loads(cursor_json)
        except Exception:
            return {}
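
    # Round-trip sketch (field names and values are illustrative, not taken from
    # the osu! API contract):
    #
    #     s = BeatmapsetFetcher._encode_cursor({"approved_date": 1693000000, "id": 42})
    #     BeatmapsetFetcher._decode_cursor(s)  # -> {"approved_date": 1693000000, "id": 42}
    #
    # Malformed input is swallowed: _decode_cursor("not-base64") returns {}.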

    async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
        logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>")
        return BeatmapsetResp.model_validate(
            await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
        )

    async def search_beatmapset(
        self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
    ) -> SearchBeatmapsetsResp:
        logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>")
        # Generate the cache key.
        cache_key = self._generate_cache_key(query, cursor)
        # Try to serve the result from cache.
        cached_result = await redis_client.get(cache_key)
        if cached_result:
            logger.opt(colors=True).debug(f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>")
            try:
                cached_data = json.loads(cached_result)
                return SearchBeatmapsetsResp.model_validate(cached_data)
            except Exception as e:
                logger.opt(colors=True).warning(
                    f"<yellow>[BeatmapsetFetcher]</yellow> Cache data invalid, fetching from API: {e}"
                )
        # Cache miss: fetch from the API.
        logger.opt(colors=True).debug("<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API")
        params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
        if query.cursor_string:
            params["cursor_string"] = query.cursor_string
        else:
            for k, v in cursor.items():
                params[f"cursor[{k}]"] = v
        api_response = await self.request_api(
            "https://osu.ppy.sh/api/v2/beatmapsets/search",
            params=params,
        )
        # Propagate cursor information from the response.
        if api_response.get("cursor"):
            cursor_dict = api_response["cursor"]
            api_response["cursor_string"] = self._encode_cursor(cursor_dict)
        # Cache the result for 15 minutes.
        cache_ttl = 15 * 60
        await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl)
        logger.opt(colors=True).debug(
            f"<green>[BeatmapsetFetcher]</green> Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)"
        )
        resp = SearchBeatmapsetsResp.model_validate(api_response)
        # Prefetch only when the user is explicitly searching (a search term,
        # non-default category, or an existing cursor), to avoid excessive
        # prefetching and API traffic during plain homepage browsing.
        if api_response.get("cursor") and (query.q or query.s != "leaderboard" or cursor):
            # Prefetch only the next page, in the background after a short delay
            # rather than immediately.
            async def delayed_prefetch():
                await asyncio.sleep(3.0)  # wait 3 seconds before prefetching
                try:
                    await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
                except RateLimitError:
                    logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch skipped due to rate limit")

            bg_tasks.add_task(delayed_prefetch)
        return resp
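
    # Illustrative route-level usage (a sketch; the fetcher construction, redis
    # wiring, and search term are assumptions, not shown in this module):
    #
    #     query = SearchQueryModel(q="kano", s="leaderboard")
    #     resp = await fetcher.search_beatmapset(query, {}, redis_client)
    #     # Subsequent pages pass the cursor from the previous response, assuming
    #     # SearchBeatmapsetsResp exposes the cursor_string field set above:
    #     cursor = BeatmapsetFetcher._decode_cursor(resp.cursor_string)
    #     page2 = await fetcher.search_beatmapset(query, cursor, redis_client)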

    async def prefetch_next_pages(
        self,
        query: SearchQueryModel,
        current_cursor: Cursor,
        redis_client: redis.Redis,
        pages: int = 3,
    ) -> None:
        """Prefetch the next few pages of results."""
        if not current_cursor:
            return
        try:
            cursor = current_cursor.copy()
            for page in range(1, pages + 1):
                # Request the next page using the current cursor.
                next_query = query.model_copy()
                logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}")
                # Generate the cache key for the next page.
                next_cache_key = self._generate_cache_key(next_query, cursor)
                # Skip pages that are already cached.
                if await redis_client.exists(next_cache_key):
                    logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached")
                    # Try to read the cursor from the cached page and keep going.
                    cached_data = await redis_client.get(next_cache_key)
                    if cached_data:
                        try:
                            data = json.loads(cached_data)
                            if data.get("cursor"):
                                cursor = data["cursor"]
                                continue
                        except Exception:
                            pass
                    break
                # Delay between prefetched pages to avoid bursts of requests.
                if page > 1:
                    await asyncio.sleep(1.5)  # 1.5 second delay
                # Request the next page of data.
                params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
                for k, v in cursor.items():
                    params[f"cursor[{k}]"] = v
                api_response = await self.request_api(
                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
                    params=params,
                )
                # Propagate cursor information from the response.
                if api_response.get("cursor"):
                    cursor_dict = api_response["cursor"]
                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
                    cursor = cursor_dict  # advance the cursor to the next page
                else:
                    # No more pages.
                    break
                # Cache prefetched results with a shorter TTL.
                prefetch_ttl = 10 * 60  # 10 minutes
                await redis_client.set(
                    next_cache_key,
                    json.dumps(api_response, separators=(",", ":")),
                    ex=prefetch_ttl,
                )
                logger.opt(colors=True).debug(
                    f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} (TTL: {prefetch_ttl}s)"
                )
        except RateLimitError:
            logger.opt(colors=True).info("<yellow>[BeatmapsetFetcher]</yellow> Prefetch stopped due to rate limit")
        except Exception as e:
            logger.opt(colors=True).warning(f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}")

    async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
        """Warm up the homepage cache."""
        homepage_queries = self._get_homepage_queries()
        logger.opt(colors=True).info(
            f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup ({len(homepage_queries)} queries)"
        )
        for i, (query, cursor) in enumerate(homepage_queries):
            try:
                # Delay between requests to avoid bursts.
                if i > 0:
                    await asyncio.sleep(2.0)  # 2 second delay
                cache_key = self._generate_cache_key(query, cursor)
                # Skip queries that are already cached.
                if await redis_client.exists(cache_key):
                    logger.opt(colors=True).debug(
                        f"<magenta>[BeatmapsetFetcher]</magenta> Query {query.sort} already cached"
                    )
                    continue
                # Fetch and cache.
                params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
                api_response = await self.request_api(
                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
                    params=params,
                )
                if api_response.get("cursor"):
                    cursor_dict = api_response["cursor"]
                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
                # Cache the result.
                cache_ttl = 20 * 60  # 20 minutes
                await redis_client.set(
                    cache_key,
                    json.dumps(api_response, separators=(",", ":")),
                    ex=cache_ttl,
                )
                logger.opt(colors=True).info(
                    f"<magenta>[BeatmapsetFetcher]</magenta> Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
                )
                if api_response.get("cursor"):
                    try:
                        await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
                    except RateLimitError:
                        logger.opt(colors=True).info(
                            f"<yellow>[BeatmapsetFetcher]</yellow> Warmup prefetch skipped for {query.sort} due to rate limit"
                        )
            except RateLimitError:
                logger.opt(colors=True).warning(
                    f"<yellow>[BeatmapsetFetcher]</yellow> Warmup skipped for {query.sort} due to rate limit"
                )
            except Exception as e:
                logger.opt(colors=True).error(
                    f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
                )
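

# Illustrative startup wiring (a sketch; the fetcher/redis construction and the
# startup hook are assumptions, not part of this module):
#
#     async def _warmup():
#         await fetcher.warmup_homepage_cache(redis_client)
#
#     bg_tasks.add_task(_warmup)  # same zero-argument pattern as delayed_prefetch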