feat(fetcher): use client_credentials grant type to avoid missing refresh token (#62)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-25 20:01:50 +08:00
committed by GitHub
parent 2c81e22749
commit 8f4a9d5fed
6 changed files with 67 additions and 137 deletions

View File

@@ -8,7 +8,7 @@ from pydantic import (
ValidationInfo, ValidationInfo,
field_validator, field_validator,
) )
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class AWSS3StorageSettings(BaseSettings): class AWSS3StorageSettings(BaseSettings):
@@ -302,16 +302,26 @@ CALCULATOR_CONFIG='{
Field(default="", description="Fetcher 客户端密钥"), Field(default="", description="Fetcher 客户端密钥"),
"Fetcher 设置", "Fetcher 设置",
] ]
fetcher_scopes: Annotated[
list[str],
Field(default=["public"], description="Fetcher 权限范围,以逗号分隔每个权限"),
"Fetcher 设置",
NoDecode,
]
@property # NOTE: Reserve for user-based-fetcher
def fetcher_callback_url(self) -> str:
return f"{self.server_url}fetcher/callback" # fetcher_scopes: Annotated[
# list[str],
# Field(default=["public"], description="Fetcher 权限范围,以逗号分隔每个权限"),
# "Fetcher 设置",
# NoDecode,
# ]
# @field_validator("fetcher_scopes", mode="before")
# @classmethod
# def validate_fetcher_scopes(cls, v: Any) -> list[str]:
# if isinstance(v, str):
# return v.split(",")
# return v
# @property
# def fetcher_callback_url(self) -> str:
# return f"{self.server_url}fetcher/callback"
# 日志设置 # 日志设置
log_level: Annotated[ log_level: Annotated[
@@ -690,13 +700,6 @@ CALCULATOR_CONFIG='{
"存储服务设置", "存储服务设置",
] ]
@field_validator("fetcher_scopes", mode="before")
@classmethod
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str):
return v.split(",")
return v
@field_validator("storage_settings", mode="after") @field_validator("storage_settings", mode="after")
@classmethod @classmethod
def validate_storage_settings( def validate_storage_settings(

View File

@@ -3,7 +3,6 @@ from typing import Annotated
from app.config import settings from app.config import settings
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.fetcher import Fetcher as OriginFetcher from app.fetcher import Fetcher as OriginFetcher
from app.log import fetcher_logger
from fastapi import Depends from fastapi import Depends
@@ -16,20 +15,16 @@ async def get_fetcher() -> OriginFetcher:
fetcher = OriginFetcher( fetcher = OriginFetcher(
settings.fetcher_client_id, settings.fetcher_client_id,
settings.fetcher_client_secret, settings.fetcher_client_secret,
settings.fetcher_scopes,
settings.fetcher_callback_url,
) )
redis = get_redis() redis = get_redis()
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
expire_at = await redis.get(f"fetcher:expire_at:{fetcher.client_id}")
if expire_at:
fetcher.token_expiry = int(float(expire_at))
if access_token: if access_token:
fetcher.access_token = str(access_token) fetcher.access_token = str(access_token)
refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") # Always ensure the access token is valid, regardless of initial state
if refresh_token: await fetcher.ensure_valid_access_token()
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
fetcher_logger("Fetcher").opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
return fetcher return fetcher

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import time import time
from urllib.parse import quote
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.log import fetcher_logger from app.log import fetcher_logger
@@ -34,13 +33,14 @@ class BaseFetcher:
self.scope = scope self.scope = scope
self._token_lock = asyncio.Lock() self._token_lock = asyncio.Lock()
@property # NOTE: Reserve for user-based fetchers
def authorize_url(self) -> str: # @property
return ( # def authorize_url(self) -> str:
f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}" # return (
f"&response_type=code&scope={quote(' '.join(self.scope))}" # f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}"
f"&redirect_uri={self.callback_url}" # f"&response_type=code&scope={quote(' '.join(self.scope))}"
) # f"&redirect_uri={self.callback_url}"
# )
@property @property
def header(self) -> dict[str, str]: def header(self) -> dict[str, str]:
@@ -53,7 +53,7 @@ class BaseFetcher:
""" """
发送 API 请求 发送 API 请求
""" """
await self._ensure_valid_access_token() await self.ensure_valid_access_token()
headers = kwargs.pop("headers", {}).copy() headers = kwargs.pop("headers", {}).copy()
attempt = 0 attempt = 0
@@ -78,28 +78,30 @@ class BaseFetcher:
logger.warning(f"Received 401 error for {url}, attempt {attempt}") logger.warning(f"Received 401 error for {url}, attempt {attempt}")
await self._handle_unauthorized() await self._handle_unauthorized()
await self._clear_tokens() await self._clear_access_token()
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}") logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens")
await self.grant_access_token()
raise TokenAuthError(f"Failed to authorize after retries for {url}")
def is_token_expired(self) -> bool: def is_token_expired(self) -> bool:
return self.token_expiry <= int(time.time()) if not isinstance(self.token_expiry, int):
return True
return self.token_expiry <= int(time.time()) or not self.access_token
async def grant_access_token(self, code: str) -> None: async def grant_access_token(self) -> None:
async with AsyncClient() as client: async with AsyncClient() as client:
response = await client.post( response = await client.post(
"https://osu.ppy.sh/oauth/token", "https://osu.ppy.sh/oauth/token",
data={ data={
"client_id": self.client_id, "client_id": self.client_id,
"client_secret": self.client_secret, "client_secret": self.client_secret,
"grant_type": "authorization_code", "grant_type": "client_credentials",
"redirect_uri": self.callback_url, "scope": "public",
"code": code,
}, },
) )
response.raise_for_status() response.raise_for_status()
token_data = response.json() token_data = response.json()
self.access_token = token_data["access_token"] self.access_token = token_data["access_token"]
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"] self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis() redis = get_redis()
await redis.set( await redis.set(
@@ -108,66 +110,20 @@ class BaseFetcher:
ex=token_data["expires_in"], ex=token_data["expires_in"],
) )
await redis.set( await redis.set(
f"fetcher:refresh_token:{self.client_id}", f"fetcher:expire_at:{self.client_id}",
self.refresh_token, self.token_expiry,
ex=token_data["expires_in"],
)
logger.success(
f"Granted new access token for client {self.client_id}, expires in {token_data['expires_in']} seconds"
) )
async def refresh_access_token(self, *, force: bool = False) -> None: async def ensure_valid_access_token(self) -> None:
if not force and not self.is_token_expired():
return
async with self._token_lock:
if not force and not self.is_token_expired():
return
if force:
await self._clear_access_token()
if not self.refresh_token:
logger.error(f"Missing refresh token for client {self.client_id}")
await self._clear_tokens()
raise TokenAuthError(f"Missing refresh token. Please re-authorize using: {self.authorize_url}")
try:
logger.info(f"Refreshing access token for client {self.client_id}")
async with AsyncClient() as client:
response = await client.post(
"https://osu.ppy.sh/oauth/token",
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
},
)
response.raise_for_status()
token_data = response.json()
self.access_token = token_data["access_token"]
self.refresh_token = token_data.get("refresh_token", self.refresh_token)
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
logger.info(f"Successfully refreshed access token for client {self.client_id}")
except Exception as e:
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
await self._clear_tokens()
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
raise
async def _ensure_valid_access_token(self) -> None:
if self.is_token_expired(): if self.is_token_expired():
await self.refresh_access_token() await self.grant_access_token()
async def _handle_unauthorized(self) -> None: async def _handle_unauthorized(self) -> None:
await self.refresh_access_token(force=True) await self.grant_access_token()
async def _clear_access_token(self) -> None: async def _clear_access_token(self) -> None:
logger.warning(f"Clearing access token for client {self.client_id}") logger.warning(f"Clearing access token for client {self.client_id}")
@@ -177,31 +133,4 @@ class BaseFetcher:
redis = get_redis() redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}") await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:expire_at:{self.client_id}")
async def _clear_tokens(self) -> None:
"""
清除所有 token
"""
logger.warning(f"Clearing tokens for client {self.client_id}")
# 清除内存中的 token
self.access_token = ""
self.refresh_token = ""
self.token_expiry = 0
# 清除 Redis 中的 token
redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
def get_auth_status(self) -> dict:
"""
获取当前授权状态信息
"""
return {
"client_id": self.client_id,
"has_access_token": bool(self.access_token),
"has_refresh_token": bool(self.refresh_token),
"token_expired": self.is_token_expired(),
"authorize_url": self.authorize_url,
}

View File

@@ -10,7 +10,7 @@ from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor from app.models.model import Cursor
from app.utils import bg_tasks from app.utils import bg_tasks
from ._base import BaseFetcher, TokenAuthError from ._base import BaseFetcher
from httpx import AsyncClient from httpx import AsyncClient
import redis.asyncio as redis import redis.asyncio as redis
@@ -25,6 +25,9 @@ class RateLimitError(Exception):
logger = fetcher_logger("BeatmapsetFetcher") logger = fetcher_logger("BeatmapsetFetcher")
MAX_RETRY_ATTEMPTS = 3
class BeatmapsetFetcher(BaseFetcher): class BeatmapsetFetcher(BaseFetcher):
@staticmethod @staticmethod
def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]: def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
@@ -46,14 +49,17 @@ class BeatmapsetFetcher(BaseFetcher):
return homepage_queries return homepage_queries
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict: async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict:
"""覆盖基类方法添加速率限制和429错误处理""" """覆盖基类方法添加速率限制和429错误处理"""
# 在请求前获取速率限制许可 # 在请求前获取速率限制许可
if retry_times > MAX_RETRY_ATTEMPTS:
raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}")
await osu_api_rate_limiter.acquire() await osu_api_rate_limiter.acquire()
# 检查 token 是否过期,如果过期则刷新 # 检查 token 是否过期,如果过期则刷新
if self.is_token_expired(): if self.is_token_expired():
await self.refresh_access_token() await self.grant_access_token()
header = kwargs.pop("headers", {}) header = kwargs.pop("headers", {})
header.update(self.header) header.update(self.header)
@@ -70,12 +76,10 @@ class BeatmapsetFetcher(BaseFetcher):
if response.status_code == 429: if response.status_code == 429:
logger.warning(f"Rate limit exceeded (429) for {url}") logger.warning(f"Rate limit exceeded (429) for {url}")
raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.") raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.")
# 处理 401 错误
if response.status_code == 401: if response.status_code == 401:
logger.warning(f"Received 401 error for {url}") logger.warning(f"Received 401 error for {url}")
await self._clear_tokens() await self._clear_access_token()
raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}") return await self.request_api(url, method, retry_times=retry_times + 1, **kwargs)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

View File

@@ -7,5 +7,5 @@ fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False)
@fetcher_router.get("/callback") @fetcher_router.get("/callback")
async def callback(code: str, fetcher: Fetcher): async def callback(code: str, fetcher: Fetcher):
await fetcher.grant_access_token(code) # await fetcher.grant_access_token(code)
return {"message": "Login successful"} return {"message": "Login successful"}

View File

@@ -23,7 +23,6 @@ from app.router import (
api_v2_router, api_v2_router,
auth_router, auth_router,
chat_router, chat_router,
fetcher_router,
file_router, file_router,
lio_router, lio_router,
private_router, private_router,
@@ -184,7 +183,7 @@ app.include_router(api_v1_router)
app.include_router(api_v1_public_router) app.include_router(api_v1_public_router)
app.include_router(chat_router) app.include_router(chat_router)
app.include_router(redirect_api_router) app.include_router(redirect_api_router)
app.include_router(fetcher_router) # app.include_router(fetcher_router)
app.include_router(file_router) app.include_router(file_router)
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(private_router) app.include_router(private_router)