feat(fetcher): use client_credentials grant type to avoid missing refresh token (#62)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-25 20:01:50 +08:00
committed by GitHub
parent 2c81e22749
commit 8f4a9d5fed
6 changed files with 67 additions and 137 deletions

View File

@@ -1,6 +1,5 @@
import asyncio
import time
from urllib.parse import quote
from app.dependencies.database import get_redis
from app.log import fetcher_logger
@@ -34,13 +33,14 @@ class BaseFetcher:
self.scope = scope
self._token_lock = asyncio.Lock()
@property
def authorize_url(self) -> str:
return (
f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}"
f"&response_type=code&scope={quote(' '.join(self.scope))}"
f"&redirect_uri={self.callback_url}"
)
# NOTE: Reserve for user-based fetchers
# @property
# def authorize_url(self) -> str:
# return (
# f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}"
# f"&response_type=code&scope={quote(' '.join(self.scope))}"
# f"&redirect_uri={self.callback_url}"
# )
@property
def header(self) -> dict[str, str]:
@@ -53,7 +53,7 @@ class BaseFetcher:
"""
发送 API 请求
"""
await self._ensure_valid_access_token()
await self.ensure_valid_access_token()
headers = kwargs.pop("headers", {}).copy()
attempt = 0
@@ -78,28 +78,30 @@ class BaseFetcher:
logger.warning(f"Received 401 error for {url}, attempt {attempt}")
await self._handle_unauthorized()
await self._clear_tokens()
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}")
await self._clear_access_token()
logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
await self.grant_access_token()
raise TokenAuthError(f"Failed to authorize after retries for {url}")
def is_token_expired(self) -> bool:
return self.token_expiry <= int(time.time())
if not isinstance(self.token_expiry, int):
return True
return self.token_expiry <= int(time.time()) or not self.access_token
async def grant_access_token(self, code: str) -> None:
async def grant_access_token(self) -> None:
async with AsyncClient() as client:
response = await client.post(
"https://osu.ppy.sh/oauth/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code",
"redirect_uri": self.callback_url,
"code": code,
"grant_type": "client_credentials",
"scope": "public",
},
)
response.raise_for_status()
token_data = response.json()
self.access_token = token_data["access_token"]
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
await redis.set(
@@ -108,66 +110,20 @@ class BaseFetcher:
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
f"fetcher:expire_at:{self.client_id}",
self.token_expiry,
ex=token_data["expires_in"],
)
logger.success(
f"Granted new access token for client {self.client_id}, expires in {token_data['expires_in']} seconds"
)
async def refresh_access_token(self, *, force: bool = False) -> None:
if not force and not self.is_token_expired():
return
async with self._token_lock:
if not force and not self.is_token_expired():
return
if force:
await self._clear_access_token()
if not self.refresh_token:
logger.error(f"Missing refresh token for client {self.client_id}")
await self._clear_tokens()
raise TokenAuthError(f"Missing refresh token. Please re-authorize using: {self.authorize_url}")
try:
logger.info(f"Refreshing access token for client {self.client_id}")
async with AsyncClient() as client:
response = await client.post(
"https://osu.ppy.sh/oauth/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
},
)
response.raise_for_status()
token_data = response.json()
self.access_token = token_data["access_token"]
self.refresh_token = token_data.get("refresh_token", self.refresh_token)
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
logger.info(f"Successfully refreshed access token for client {self.client_id}")
except Exception as e:
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
await self._clear_tokens()
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
raise
async def _ensure_valid_access_token(self) -> None:
async def ensure_valid_access_token(self) -> None:
if self.is_token_expired():
await self.refresh_access_token()
await self.grant_access_token()
async def _handle_unauthorized(self) -> None:
await self.refresh_access_token(force=True)
await self.grant_access_token()
async def _clear_access_token(self) -> None:
logger.warning(f"Clearing access token for client {self.client_id}")
@@ -177,31 +133,4 @@ class BaseFetcher:
redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}")
async def _clear_tokens(self) -> None:
"""
清除所有 token
"""
logger.warning(f"Clearing tokens for client {self.client_id}")
# 清除内存中的 token
self.access_token = ""
self.refresh_token = ""
self.token_expiry = 0
# 清除 Redis 中的 token
redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
def get_auth_status(self) -> dict:
"""
获取当前授权状态信息
"""
return {
"client_id": self.client_id,
"has_access_token": bool(self.access_token),
"has_refresh_token": bool(self.refresh_token),
"token_expired": self.is_token_expired(),
"authorize_url": self.authorize_url,
}
await redis.delete(f"fetcher:expire_at:{self.client_id}")