From 8f4a9d5fed6a508b3f3682f1ef8885f4e4d8db0d Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 25 Oct 2025 20:01:50 +0800 Subject: [PATCH] feat(fetcher): use `client_credentials` grant type to avoid missing refresh token (#62) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- app/config.py | 37 ++++++----- app/dependencies/fetcher.py | 15 ++--- app/fetcher/_base.py | 129 ++++++++---------------------------- app/fetcher/beatmapset.py | 18 +++-- app/router/fetcher.py | 2 +- main.py | 3 +- 6 files changed, 67 insertions(+), 137 deletions(-) diff --git a/app/config.py b/app/config.py index d2d5daf..2f0791f 100644 --- a/app/config.py +++ b/app/config.py @@ -8,7 +8,7 @@ from pydantic import ( ValidationInfo, field_validator, ) -from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict +from pydantic_settings import BaseSettings, SettingsConfigDict class AWSS3StorageSettings(BaseSettings): @@ -302,16 +302,26 @@ CALCULATOR_CONFIG='{ Field(default="", description="Fetcher 客户端密钥"), "Fetcher 设置", ] - fetcher_scopes: Annotated[ - list[str], - Field(default=["public"], description="Fetcher 权限范围,以逗号分隔每个权限"), - "Fetcher 设置", - NoDecode, - ] - @property - def fetcher_callback_url(self) -> str: - return f"{self.server_url}fetcher/callback" + # NOTE: Reserve for user-based-fetcher + + # fetcher_scopes: Annotated[ + # list[str], + # Field(default=["public"], description="Fetcher 权限范围,以逗号分隔每个权限"), + # "Fetcher 设置", + # NoDecode, + # ] + + # @field_validator("fetcher_scopes", mode="before") + # @classmethod + # def validate_fetcher_scopes(cls, v: Any) -> list[str]: + # if isinstance(v, str): + # return v.split(",") + # return v + + # @property + # def fetcher_callback_url(self) -> str: + # return f"{self.server_url}fetcher/callback" # 日志设置 log_level: Annotated[ @@ -690,13 +700,6 @@ CALCULATOR_CONFIG='{ "存储服务设置", ] - @field_validator("fetcher_scopes", mode="before") - @classmethod - def validate_fetcher_scopes(cls, v: Any) -> list[str]: - if isinstance(v, str): - return v.split(",") - return v - @field_validator("storage_settings", mode="after") @classmethod def validate_storage_settings( diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index 9aecbb4..f9c91f1 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -3,7 +3,6 @@ from typing import Annotated from app.config import settings from app.dependencies.database import get_redis from app.fetcher import Fetcher as OriginFetcher -from app.log import fetcher_logger from fastapi import Depends @@ -16,20 +15,16 @@ async def get_fetcher() -> OriginFetcher: fetcher = OriginFetcher( settings.fetcher_client_id, settings.fetcher_client_secret, - settings.fetcher_scopes, - settings.fetcher_callback_url, ) redis = get_redis() access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") + expire_at = await redis.get(f"fetcher:expire_at:{fetcher.client_id}") + if expire_at: + fetcher.token_expiry = int(float(expire_at)) if access_token: fetcher.access_token = str(access_token) - refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") - if refresh_token: - fetcher.refresh_token = str(refresh_token) - if not fetcher.access_token or not fetcher.refresh_token: - fetcher_logger("Fetcher").opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + # Always ensure the access token is valid, regardless of initial state + await fetcher.ensure_valid_access_token() return fetcher diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 4f9f7e5..1757cb9 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -1,6 +1,5 @@ import asyncio import time -from urllib.parse import quote from app.dependencies.database import get_redis from app.log import fetcher_logger @@ -34,13 +33,14 @@ class BaseFetcher: self.scope = scope self._token_lock = asyncio.Lock() - @property - def authorize_url(self) -> str: - return ( - f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}" - f"&response_type=code&scope={quote(' '.join(self.scope))}" - f"&redirect_uri={self.callback_url}" - ) + # NOTE: Reserve for user-based fetchers + # @property + # def authorize_url(self) -> str: + # return ( + # f"https://osu.ppy.sh/oauth/authorize?client_id={self.client_id}" + # f"&response_type=code&scope={quote(' '.join(self.scope))}" + # f"&redirect_uri={self.callback_url}" + # ) @property def header(self) -> dict[str, str]: @@ -53,7 +53,7 @@ class BaseFetcher: """ 发送 API 请求 """ - await self._ensure_valid_access_token() + await self.ensure_valid_access_token() headers = kwargs.pop("headers", {}).copy() attempt = 0 @@ -78,28 +78,30 @@ class BaseFetcher: logger.warning(f"Received 401 error for {url}, attempt {attempt}") await self._handle_unauthorized() - await self._clear_tokens() - raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}") + await self._clear_access_token() + logger.warning(f"Failed to authorize after retries for {url}, cleaned up tokens") + await self.grant_access_token() + raise TokenAuthError(f"Failed to authorize after retries for {url}") def is_token_expired(self) -> bool: - return self.token_expiry <= int(time.time()) + if not isinstance(self.token_expiry, int): + return True + return self.token_expiry <= int(time.time()) or not self.access_token - async def grant_access_token(self, code: str) -> None: + async def grant_access_token(self) -> None: async with AsyncClient() as client: response = await client.post( "https://osu.ppy.sh/oauth/token", data={ "client_id": self.client_id, "client_secret": self.client_secret, - "grant_type": "authorization_code", - "redirect_uri": self.callback_url, - "code": code, + "grant_type": "client_credentials", + "scope": "public", }, ) response.raise_for_status() token_data = response.json() self.access_token = token_data["access_token"] - self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() await redis.set( @@ -108,66 +110,20 @@ class BaseFetcher: ex=token_data["expires_in"], ) await redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, + f"fetcher:expire_at:{self.client_id}", + self.token_expiry, + ex=token_data["expires_in"], + ) + logger.success( + f"Granted new access token for client {self.client_id}, expires in {token_data['expires_in']} seconds" ) - async def refresh_access_token(self, *, force: bool = False) -> None: - if not force and not self.is_token_expired(): - return - - async with self._token_lock: - if not force and not self.is_token_expired(): - return - - if force: - await self._clear_access_token() - - if not self.refresh_token: - logger.error(f"Missing refresh token for client {self.client_id}") - await self._clear_tokens() - raise TokenAuthError(f"Missing refresh token. Please re-authorize using: {self.authorize_url}") - - try: - logger.info(f"Refreshing access token for client {self.client_id}") - async with AsyncClient() as client: - response = await client.post( - "https://osu.ppy.sh/oauth/token", - data={ - "client_id": self.client_id, - "client_secret": self.client_secret, - "grant_type": "refresh_token", - "refresh_token": self.refresh_token, - }, - ) - response.raise_for_status() - token_data = response.json() - self.access_token = token_data["access_token"] - self.refresh_token = token_data.get("refresh_token", self.refresh_token) - self.token_expiry = int(time.time()) + token_data["expires_in"] - redis = get_redis() - await redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - await redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) - logger.info(f"Successfully refreshed access token for client {self.client_id}") - except Exception as e: - logger.error(f"Failed to refresh access token for client {self.client_id}: {e}") - await self._clear_tokens() - logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}") - raise - - async def _ensure_valid_access_token(self) -> None: + async def ensure_valid_access_token(self) -> None: if self.is_token_expired(): - await self.refresh_access_token() + await self.grant_access_token() async def _handle_unauthorized(self) -> None: - await self.refresh_access_token(force=True) + await self.grant_access_token() async def _clear_access_token(self) -> None: logger.warning(f"Clearing access token for client {self.client_id}") @@ -177,31 +133,4 @@ class BaseFetcher: redis = get_redis() await redis.delete(f"fetcher:access_token:{self.client_id}") - - async def _clear_tokens(self) -> None: - """ - 清除所有 token - """ - logger.warning(f"Clearing tokens for client {self.client_id}") - - # 清除内存中的 token - self.access_token = "" - self.refresh_token = "" - self.token_expiry = 0 - - # 清除 Redis 中的 token - redis = get_redis() - await redis.delete(f"fetcher:access_token:{self.client_id}") - await redis.delete(f"fetcher:refresh_token:{self.client_id}") - - def get_auth_status(self) -> dict: - """ - 获取当前授权状态信息 - """ - return { - "client_id": self.client_id, - "has_access_token": bool(self.access_token), - "has_refresh_token": bool(self.refresh_token), - "token_expired": self.is_token_expired(), - "authorize_url": self.authorize_url, - } + await redis.delete(f"fetcher:expire_at:{self.client_id}") diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index b228930..893552f 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -10,7 +10,7 @@ from app.models.beatmap import SearchQueryModel from app.models.model import Cursor from app.utils import bg_tasks -from ._base import BaseFetcher, TokenAuthError +from ._base import BaseFetcher from httpx import AsyncClient import redis.asyncio as redis @@ -25,6 +25,9 @@ class RateLimitError(Exception): logger = fetcher_logger("BeatmapsetFetcher") +MAX_RETRY_ATTEMPTS = 3 + + class BeatmapsetFetcher(BaseFetcher): @staticmethod def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]: @@ -46,14 +49,17 @@ class BeatmapsetFetcher(BaseFetcher): return homepage_queries - async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict: + async def request_api(self, url: str, method: str = "GET", *, retry_times: int = 0, **kwargs) -> dict: """覆盖基类方法,添加速率限制和429错误处理""" # 在请求前获取速率限制许可 + if retry_times > MAX_RETRY_ATTEMPTS: + raise RuntimeError(f"Maximum retry attempts ({MAX_RETRY_ATTEMPTS}) reached for API request to {url}") + await osu_api_rate_limiter.acquire() # 检查 token 是否过期,如果过期则刷新 if self.is_token_expired(): - await self.refresh_access_token() + await self.grant_access_token() header = kwargs.pop("headers", {}) header.update(self.header) @@ -70,12 +76,10 @@ class BeatmapsetFetcher(BaseFetcher): if response.status_code == 429: logger.warning(f"Rate limit exceeded (429) for {url}") raise RateLimitError(f"Rate limit exceeded for {url}. Please try again later.") - - # 处理 401 错误 if response.status_code == 401: logger.warning(f"Received 401 error for {url}") - await self._clear_tokens() - raise TokenAuthError(f"Authentication failed. Please re-authorize using: {self.authorize_url}") + await self._clear_access_token() + return await self.request_api(url, method, retry_times=retry_times + 1, **kwargs) response.raise_for_status() return response.json() diff --git a/app/router/fetcher.py b/app/router/fetcher.py index 23bf8e6..2ebb5d3 100644 --- a/app/router/fetcher.py +++ b/app/router/fetcher.py @@ -7,5 +7,5 @@ fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False) @fetcher_router.get("/callback") async def callback(code: str, fetcher: Fetcher): - await fetcher.grant_access_token(code) + # await fetcher.grant_access_token(code) return {"message": "Login successful"} diff --git a/main.py b/main.py index 05efd64..dded8c4 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,6 @@ from app.router import ( api_v2_router, auth_router, chat_router, - fetcher_router, file_router, lio_router, private_router, @@ -184,7 +183,7 @@ app.include_router(api_v1_router) app.include_router(api_v1_public_router) app.include_router(chat_router) app.include_router(redirect_api_router) -app.include_router(fetcher_router) +# app.include_router(fetcher_router) app.include_router(file_router) app.include_router(auth_router) app.include_router(private_router)