feat(fetcher): use client_credentials grant type to avoid missing refresh token (#62)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-25 20:01:50 +08:00
committed by GitHub
parent 2c81e22749
commit 8f4a9d5fed
6 changed files with 67 additions and 137 deletions

View File

@@ -10,7 +10,7 @@ from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
from app.utils import bg_tasks
from ._base import BaseFetcher, TokenAuthError
from ._base import BaseFetcher
from httpx import AsyncClient
import redis.asyncio as redis
@@ -25,6 +25,9 @@ class RateLimitError(Exception):
logger = fetcher_logger("BeatmapsetFetcher")
MAX_RETRY_ATTEMPTS = 3
class BeatmapsetFetcher(BaseFetcher):
@staticmethod
def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
@@ -46,14 +49,17 @@ class BeatmapsetFetcher(BaseFetcher):
return homepage_queries
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict:
"""覆盖基类方法添加速率限制和429错误处理"""
# 在请求前获取速率限制许可
if retry_times > MAX_RETRY_ATTEMPTS:
raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}")
await osu_api_rate_limiter.acquire()
# 检查 token 是否过期,如果过期则刷新
if self.is_token_expired():
await self.refresh_access_token()
await self.grant_access_token()
header = kwargs.pop("headers", {})
header.update(self.header)
@@ -70,12 +76,10 @@ class BeatmapsetFetcher(BaseFetcher):
if response.status_code == 429:
logger.warning(f"Rate limit exceeded (429) for {url}")
raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.")
# 处理 401 错误
if response.status_code == 401:
logger.warning(f"Received 401 error for {url}")
await self._clear_tokens()
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
await self._clear_access_token()
return await self.request_api(url, method, retry_times=retry_times + 1, **kwargs)
response.raise_for_status()
return response.json()