237 lines
8.8 KiB
Python
237 lines
8.8 KiB
Python
import asyncio
|
||
from datetime import datetime
|
||
import time
|
||
|
||
from app.dependencies.database import get_redis
|
||
from app.log import fetcher_logger
|
||
|
||
from httpx import AsyncClient, HTTPStatusError, TimeoutException
|
||
|
||
|
||
class TokenAuthError(Exception):
|
||
"""Token 授权失败异常"""
|
||
|
||
pass
|
||
|
||
|
||
class PassiveRateLimiter:
|
||
"""
|
||
被动速率限制器
|
||
当收到 429 响应时,读取 Retry-After 头并暂停所有请求
|
||
"""
|
||
|
||
def __init__(self):
|
||
self._lock = asyncio.Lock()
|
||
self._retry_after_time: float | None = None
|
||
self._waiting_tasks: set[asyncio.Task] = set()
|
||
|
||
async def wait_if_limited(self) -> None:
|
||
"""如果正在限流中,等待限流解除"""
|
||
async with self._lock:
|
||
if self._retry_after_time is not None:
|
||
current_time = time.time()
|
||
if current_time < self._retry_after_time:
|
||
wait_seconds = self._retry_after_time - current_time
|
||
logger.warning(f"Rate limited, waiting {wait_seconds:.2f} seconds")
|
||
await asyncio.sleep(wait_seconds)
|
||
self._retry_after_time = None
|
||
|
||
async def handle_rate_limit(self, retry_after: str | int | None) -> None:
|
||
"""
|
||
处理 429 响应,设置限流时间
|
||
|
||
Args:
|
||
retry_after: Retry-After 头的值,可以是秒数或 HTTP 日期
|
||
"""
|
||
async with self._lock:
|
||
if retry_after is None:
|
||
# 如果没有 Retry-After 头,默认等待 60 秒
|
||
wait_seconds = 60
|
||
elif isinstance(retry_after, int):
|
||
wait_seconds = retry_after
|
||
elif retry_after.isdigit():
|
||
wait_seconds = int(retry_after)
|
||
else:
|
||
# 尝试解析 HTTP 日期格式
|
||
try:
|
||
retry_time = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z")
|
||
wait_seconds = max(0, (retry_time - datetime.utcnow()).total_seconds())
|
||
except ValueError:
|
||
# 解析失败,默认等待 60 秒
|
||
wait_seconds = 60
|
||
|
||
self._retry_after_time = time.time() + wait_seconds
|
||
logger.warning(f"Rate limit triggered, will retry after {wait_seconds} seconds")
|
||
|
||
|
||
logger = fetcher_logger("Fetcher")
|
||
|
||
|
||
class BaseFetcher:
|
||
# 类级别的 rate limiter,所有实例共享
|
||
_rate_limiter = PassiveRateLimiter()
|
||
|
||
def __init__(
|
||
self,
|
||
client_id: str,
|
||
client_secret: str,
|
||
scope: list[str] = ["public"],
|
||
callback_url: str = "",
|
||
):
|
||
self.client_id = client_id
|
||
self.client_secret = client_secret
|
||
self.access_token: str = ""
|
||
self.refresh_token: str = ""
|
||
self.token_expiry: int = 0
|
||
self.callback_url: str = callback_url
|
||
self.scope = scope
|
||
self._token_lock = asyncio.Lock()
|
||
|
||
# NOTE: Reserve for user-based fetchers
|
||
# @property
|
||
# def authorize_url(self) -> str:
|
||
# return (
|
||
# f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}"
|
||
# f"&response_type=code&scope={quote(' '.join(self.scope))}"
|
||
# f"&redirect_uri={self.callback_url}"
|
||
# )
|
||
|
||
@property
|
||
def header(self) -> dict[str, str]:
|
||
return {
|
||
"Authorization": f"Bearer {self.access_token}",
|
||
"Content-Type": "application/json",
|
||
}
|
||
|
||
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
|
||
"""
|
||
发送 API 请求,支持被动速率限制
|
||
"""
|
||
await self.ensure_valid_access_token()
|
||
|
||
headers = kwargs.pop("headers", {}).copy()
|
||
attempt = 0
|
||
|
||
while attempt < 2:
|
||
# 在发送请求前等待速率限制
|
||
await self._rate_limiter.wait_if_limited()
|
||
|
||
request_headers = {**headers, **self.header}
|
||
request_kwargs = kwargs.copy()
|
||
|
||
async with AsyncClient() as client:
|
||
try:
|
||
response = await client.request(
|
||
method,
|
||
url,
|
||
headers=request_headers,
|
||
**request_kwargs,
|
||
)
|
||
response.raise_for_status()
|
||
return response.json()
|
||
|
||
except HTTPStatusError as e:
|
||
# 处理 429 速率限制响应
|
||
if e.response.status_code == 429:
|
||
retry_after = e.response.headers.get("Retry-After")
|
||
logger.warning(f"Rate limited for {url}, Retry-After: {retry_after}")
|
||
await self._rate_limiter.handle_rate_limit(retry_after)
|
||
# 速率限制后重试当前请求(不增加 attempt)
|
||
continue
|
||
|
||
# 处理 401 未授权响应
|
||
if e.response.status_code == 401:
|
||
attempt += 1
|
||
logger.warning(f"Received 401 error for {url}, attempt {attempt}")
|
||
await self._handle_unauthorized()
|
||
continue
|
||
|
||
# 其他 HTTP 错误直接抛出
|
||
raise
|
||
|
||
await self._clear_access_token()
|
||
logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
|
||
await self.grant_access_token()
|
||
raise TokenAuthError(f"Failed to authorize after retries for {url}")
|
||
|
||
def is_token_expired(self) -> bool:
|
||
if not isinstance(self.token_expiry, int):
|
||
return True
|
||
return self.token_expiry <= int(time.time()) or not self.access_token
|
||
|
||
async def grant_access_token(self, retries: int = 3, backoff: float = 1.0) -> None:
|
||
last_error: Exception | None = None
|
||
async with AsyncClient(timeout=30.0) as client:
|
||
for attempt in range(1, retries + 1):
|
||
try:
|
||
response = await client.post(
|
||
"https://osu.ppy.sh/oauth/token",
|
||
data={
|
||
"client_id": self.client_id,
|
||
"client_secret": self.client_secret,
|
||
"grant_type": "client_credentials",
|
||
"scope": "public",
|
||
},
|
||
)
|
||
response.raise_for_status()
|
||
token_data = response.json()
|
||
self.access_token = token_data["access_token"]
|
||
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
||
redis = get_redis()
|
||
await redis.set(
|
||
f"fetcher:access_token:{self.client_id}",
|
||
self.access_token,
|
||
ex=token_data["expires_in"],
|
||
)
|
||
await redis.set(
|
||
f"fetcher:expire_at:{self.client_id}",
|
||
self.token_expiry,
|
||
ex=token_data["expires_in"],
|
||
)
|
||
logger.success(
|
||
f"Granted new access token for client {self.client_id}, "
|
||
f"expires in {token_data['expires_in']} seconds"
|
||
)
|
||
return
|
||
|
||
except TimeoutException as exc:
|
||
last_error = exc
|
||
logger.warning(
|
||
f"Timed out while requesting access token for "
|
||
f"client {self.client_id} (attempt {attempt}/{retries})"
|
||
)
|
||
except HTTPStatusError as exc:
|
||
last_error = exc
|
||
logger.warning(
|
||
f"HTTP error while requesting access token for client {self.client_id}"
|
||
f" (status: {exc.response.status_code}, attempt {attempt}/{retries})"
|
||
)
|
||
except Exception as exc:
|
||
last_error = exc
|
||
logger.exception(
|
||
f"Unexpected error while requesting access token for client {self.client_id}"
|
||
f" (attempt {attempt}/{retries})"
|
||
)
|
||
|
||
if attempt < retries:
|
||
await asyncio.sleep(backoff * attempt)
|
||
|
||
raise TokenAuthError("Failed to grant access token after retries") from last_error
|
||
|
||
async def ensure_valid_access_token(self) -> None:
|
||
if self.is_token_expired():
|
||
await self.grant_access_token()
|
||
|
||
async def _handle_unauthorized(self) -> None:
|
||
await self.grant_access_token()
|
||
|
||
async def _clear_access_token(self) -> None:
|
||
logger.warning(f"Clearing access token for client {self.client_id}")
|
||
|
||
self.access_token = ""
|
||
self.token_expiry = 0
|
||
|
||
redis = get_redis()
|
||
await redis.delete(f"fetcher:access_token:{self.client_id}")
|
||
await redis.delete(f"fetcher:expire_at:{self.client_id}")
|