Files
g0v0-server/app/fetcher/_base.py
2025-11-30 00:43:33 +08:00

237 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
import time

from app.dependencies.database import get_redis
from app.log import fetcher_logger

from httpx import AsyncClient, HTTPStatusError, TimeoutException
class TokenAuthError(Exception):
    """Raised when token authorization fails."""
class PassiveRateLimiter:
"""
被动速率限制器
当收到 429 响应时,读取 Retry-After 头并暂停所有请求
"""
def __init__(self):
self._lock = asyncio.Lock()
self._retry_after_time: float | None = None
self._waiting_tasks: set[asyncio.Task] = set()
async def wait_if_limited(self) -> None:
"""如果正在限流中,等待限流解除"""
async with self._lock:
if self._retry_after_time is not None:
current_time = time.time()
if current_time < self._retry_after_time:
wait_seconds = self._retry_after_time - current_time
logger.warning(f"Rate limited, waiting {wait_seconds:.2f} seconds")
await asyncio.sleep(wait_seconds)
self._retry_after_time = None
async def handle_rate_limit(self, retry_after: str | int | None) -> None:
"""
处理 429 响应,设置限流时间
Args:
retry_after: Retry-After 头的值,可以是秒数或 HTTP 日期
"""
async with self._lock:
if retry_after is None:
# 如果没有 Retry-After 头,默认等待 60 秒
wait_seconds = 60
elif isinstance(retry_after, int):
wait_seconds = retry_after
elif retry_after.isdigit():
wait_seconds = int(retry_after)
else:
# 尝试解析 HTTP 日期格式
try:
retry_time = datetime.strptime(retry_after, "%a, %d %b %Y %H:%M:%S %Z")
wait_seconds = max(0, (retry_time - datetime.utcnow()).total_seconds())
except ValueError:
# 解析失败,默认等待 60 秒
wait_seconds = 60
self._retry_after_time = time.time() + wait_seconds
logger.warning(f"Rate limit triggered, will retry after {wait_seconds} seconds")
# Module-level logger created via app.log.fetcher_logger; the classes in this
# file look it up by name at call time.
logger = fetcher_logger("Fetcher")
class BaseFetcher:
    """Base API client that authenticates with an OAuth client-credentials token.

    Tokens are requested from ``https://osu.ppy.sh/oauth/token``, kept on the
    instance and mirrored into Redis. Requests go through
    :meth:`request_api`, which handles 429 and 401 responses.
    """

    # Class-level rate limiter, shared by all instances.
    _rate_limiter = PassiveRateLimiter()

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        scope: list[str] | None = None,
        callback_url: str = "",
    ):
        """Store the client credentials and initialise an empty token state.

        Args:
            client_id: OAuth client id.
            client_secret: OAuth client secret.
            scope: Requested scopes; defaults to ``["public"]``. Stored on the
                instance only; the token request in this class always sends
                ``"public"``.
            callback_url: Stored on the instance; not used by this class.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.access_token: str = ""
        self.refresh_token: str = ""
        self.token_expiry: int = 0
        self.callback_url: str = callback_url
        # Build a fresh list per instance instead of sharing one mutable default.
        self.scope = ["public"] if scope is None else scope
        # Not acquired by any method of this class.
        self._token_lock = asyncio.Lock()

    # NOTE: Reserve for user-based fetchers
    # @property
    # def authorize_url(self) -> str:
    #     return (
    #         f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}"
    #         f"&response_type=code&scope={quote(' '.join(self.scope))}"
    #         f"&redirect_uri={self.callback_url}"
    #     )

    @property
    def header(self) -> dict[str, str]:
        """Return the authorization and content-type headers for API requests."""
        return {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json",
        }

    async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
        """Send an API request with passive rate limiting and 401 recovery.

        Args:
            url: Target URL.
            method: HTTP method, ``"GET"`` by default.
            **kwargs: Forwarded to ``AsyncClient.request``. ``headers`` is
                merged with :attr:`header`; the latter wins on conflicts.

        Returns:
            The decoded JSON body of the successful response.

        Raises:
            TokenAuthError: After two 401 responses in a row, or when
                requesting a token fails.
            HTTPStatusError: For any error status other than 429 and 401.
        """
        await self.ensure_valid_access_token()

        # Tolerate an explicit ``headers=None`` and never mutate the caller's mapping.
        headers = dict(kwargs.pop("headers", None) or {})
        attempt = 0

        while attempt < 2:
            # Wait out any active rate limit before sending.
            await self._rate_limiter.wait_if_limited()

            # Rebuilt every iteration so a refreshed token is picked up.
            request_headers = {**headers, **self.header}
            request_kwargs = kwargs.copy()

            async with AsyncClient() as client:
                try:
                    response = await client.request(
                        method,
                        url,
                        headers=request_headers,
                        **request_kwargs,
                    )
                    response.raise_for_status()
                    return response.json()

                except HTTPStatusError as e:
                    # 429: record the limit and retry. ``attempt`` is not
                    # incremented, so this repeats for as long as the server
                    # keeps answering 429.
                    if e.response.status_code == 429:
                        retry_after = e.response.headers.get("Retry-After")
                        logger.warning(f"Rate limited for {url}, Retry-After: {retry_after}")
                        await self._rate_limiter.handle_rate_limit(retry_after)
                        continue

                    # 401: request a new token and retry, at most twice.
                    if e.response.status_code == 401:
                        attempt += 1
                        logger.warning(f"Received 401 error for {url}, attempt {attempt}")
                        await self._handle_unauthorized()
                        continue

                    # Any other HTTP error is propagated unchanged.
                    raise

        # Both attempts ended in 401: drop the stored token, request a new one
        # for later calls, and report the failure.
        await self._clear_access_token()
        logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
        await self.grant_access_token()
        raise TokenAuthError(f"Failed to authorize after retries for {url}")

    def is_token_expired(self) -> bool:
        """Return True when no usable access token is held."""
        if not isinstance(self.token_expiry, int):
            return True
        return self.token_expiry <= int(time.time()) or not self.access_token

    async def grant_access_token(self, retries: int = 3, backoff: float = 1.0) -> None:
        """Request a new client-credentials token and store it.

        On success the token and its expiry are kept on the instance and
        written to Redis with a TTL equal to ``expires_in``.

        Args:
            retries: Maximum number of attempts.
            backoff: Base delay in seconds; attempt ``n`` is followed by a
                sleep of ``backoff * n`` before the next attempt.

        Raises:
            TokenAuthError: When every attempt failed; chained to the last error.
        """
        last_error: Exception | None = None
        async with AsyncClient(timeout=30.0) as client:
            for attempt in range(1, retries + 1):
                try:
                    response = await client.post(
                        "https://osu.ppy.sh/oauth/token",
                        data={
                            "client_id": self.client_id,
                            "client_secret": self.client_secret,
                            "grant_type": "client_credentials",
                            "scope": "public",
                        },
                    )
                    response.raise_for_status()
                    token_data = response.json()
                    self.access_token = token_data["access_token"]
                    self.token_expiry = int(time.time()) + token_data["expires_in"]
                    redis = get_redis()
                    await redis.set(
                        f"fetcher:access_token:{self.client_id}",
                        self.access_token,
                        ex=token_data["expires_in"],
                    )
                    await redis.set(
                        f"fetcher:expire_at:{self.client_id}",
                        self.token_expiry,
                        ex=token_data["expires_in"],
                    )
                    logger.success(
                        f"Granted new access token for client {self.client_id}, "
                        f"expires in {token_data['expires_in']} seconds"
                    )
                    return

                except TimeoutException as exc:
                    last_error = exc
                    logger.warning(
                        f"Timed out while requesting access token for "
                        f"client {self.client_id} (attempt {attempt}/{retries})"
                    )
                except HTTPStatusError as exc:
                    last_error = exc
                    logger.warning(
                        f"HTTP error while requesting access token for client {self.client_id}"
                        f" (status: {exc.response.status_code}, attempt {attempt}/{retries})"
                    )
                except Exception as exc:
                    # Remember the error so the final TokenAuthError can chain to it.
                    last_error = exc
                    logger.exception(
                        f"Unexpected error while requesting access token for client {self.client_id}"
                        f" (attempt {attempt}/{retries})"
                    )

                # Linearly growing delay between attempts; none after the last one.
                if attempt < retries:
                    await asyncio.sleep(backoff * attempt)

        raise TokenAuthError("Failed to grant access token after retries") from last_error

    async def ensure_valid_access_token(self) -> None:
        """Request a new token when the current one is missing or expired."""
        if self.is_token_expired():
            await self.grant_access_token()

    async def _handle_unauthorized(self) -> None:
        """React to a 401 response by requesting a new token."""
        await self.grant_access_token()

    async def _clear_access_token(self) -> None:
        """Reset the in-memory token and delete its Redis entries."""
        logger.warning(f"Clearing access token for client {self.client_id}")

        self.access_token = ""
        self.token_expiry = 0

        redis = get_redis()
        await redis.delete(f"fetcher:access_token:{self.client_id}")
        await redis.delete(f"fetcher:expire_at:{self.client_id}")