From b3fff65e35a923409939abe748778947b417f58c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= Date: Mon, 18 Aug 2025 17:41:10 +0800 Subject: [PATCH] fix token --- app/calculator.py | 4 +- app/fetcher/_base.py | 218 ++++++++++++++++++++++++++++++------- app/fetcher/beatmap_raw.py | 6 +- 3 files changed, 186 insertions(+), 42 deletions(-) diff --git a/app/calculator.py b/app/calculator.py index da2d258..cc340cf 100644 --- a/app/calculator.py +++ b/app/calculator.py @@ -89,7 +89,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f # 使用线程池执行计算密集型操作以避免阻塞事件循环 import asyncio loop = asyncio.get_event_loop() - + def _calculate_pp_sync(): map = rosu.Beatmap(content=beatmap) mods = deepcopy(score.mods.copy()) @@ -111,7 +111,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f misses=score.nmiss, ) return perf.calculate(map) - + # 在线程池中执行计算 attrs = await loop.run_in_executor(None, _calculate_pp_sync) pp = attrs.pp diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 7e7e35b..ec0bb4a 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -1,12 +1,19 @@ from __future__ import annotations import time +from typing import Optional from app.dependencies.database import get_redis +from app.log import logger from httpx import AsyncClient +class TokenAuthError(Exception): + """Token 授权失败异常""" + pass + + class BaseFetcher: def __init__( self, @@ -14,6 +21,7 @@ class BaseFetcher: client_secret: str, scope: list[str] = ["public"], callback_url: str = "", + max_retries: int = 3, ): self.client_id = client_id self.client_secret = client_secret @@ -22,6 +30,8 @@ class BaseFetcher: self.token_expiry: int = 0 self.callback_url: str = callback_url self.scope = scope + self.max_retries = max_retries + self._auth_retry_count = 0 # 授权重试计数器 @property def authorize_url(self) -> str: @@ -39,20 +49,91 @@ class BaseFetcher: } async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict: - if self.is_token_expired(): - await self.refresh_access_token() - header = kwargs.pop("headers", {}) - header = self.header + """ + 发送 API 请求,具有智能重试和自动重新授权机制 + """ + return await self._request_with_retry(url, method, **kwargs) - async with AsyncClient() as client: - response = await client.request( - method, - url, - headers=header, - **kwargs, - ) - response.raise_for_status() - return response.json() + async def _request_with_retry( + self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs + ) -> dict: + """ + 带重试机制的请求方法 + """ + if max_retries is None: + max_retries = self.max_retries + + last_error = None + + for attempt in range(max_retries + 1): + try: + # 检查 token 是否过期 + if self.is_token_expired(): + await self.refresh_access_token() + + header = kwargs.pop("headers", {}) + header.update(self.header) + + async with AsyncClient() as client: + response = await client.request( + method, + url, + headers=header, + **kwargs, + ) + + # 处理 401 错误 + if response.status_code == 401: + self._auth_retry_count += 1 + logger.warning( + f"Received 401 error (attempt {attempt + 1}/{max_retries + 1}) " + f"for {url}, auth retry count: {self._auth_retry_count}" + ) + + # 如果达到最大重试次数,触发重新授权 + if self._auth_retry_count >= self.max_retries: + await self._trigger_reauthorization() + raise TokenAuthError( + f"Authentication failed after {self._auth_retry_count} attempts. " + f"Please re-authorize using: {self.authorize_url}" + ) + + # 如果还有重试机会,刷新 token 后继续 + if attempt < max_retries: + await self.refresh_access_token() + continue + else: + # 最后一次重试也失败了 + await self._trigger_reauthorization() + raise TokenAuthError( + f"Max retries ({max_retries}) exceeded for authentication. " + f"Please re-authorize using: {self.authorize_url}" + ) + + # 请求成功,重置重试计数器 + self._auth_retry_count = 0 + response.raise_for_status() + return response.json() + + except TokenAuthError: + # 重新抛出授权错误 + raise + except Exception as e: + last_error = e + if attempt < max_retries: + logger.warning( + f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying..." + ) + continue + else: + logger.error(f"Request failed after {max_retries + 1} attempts: {e}") + break + + # 如果所有重试都失败了 + if last_error: + raise last_error + else: + raise Exception(f"Request to {url} failed after {max_retries + 1} attempts") def is_token_expired(self) -> bool: return self.token_expiry <= int(time.time()) @@ -86,28 +167,91 @@ class BaseFetcher: ) async def refresh_access_token(self) -> None: - async with AsyncClient() as client: - response = await client.post( - "https://osu.ppy.sh/oauth/token", - data={ - "client_id": self.client_id, - "client_secret": self.client_secret, - "grant_type": "refresh_token", - "refresh_token": self.refresh_token, - }, - ) - response.raise_for_status() - token_data = response.json() - self.access_token = token_data["access_token"] - self.refresh_token = token_data.get("refresh_token", "") - self.token_expiry = int(time.time()) + token_data["expires_in"] + try: + logger.info(f"Refreshing access token for client {self.client_id}") + async with AsyncClient() as client: + response = await client.post( + "https://osu.ppy.sh/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + response.raise_for_status() + token_data = response.json() + self.access_token = token_data["access_token"] + self.refresh_token = token_data.get("refresh_token", "") + self.token_expiry = int(time.time()) + token_data["expires_in"] + redis = get_redis() + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) + logger.info(f"Successfully refreshed access token for client {self.client_id}") + except Exception as e: + logger.error(f"Failed to refresh access token for client {self.client_id}: {e}") + # 清除无效的 token,要求重新授权 + self.access_token = "" + self.refresh_token = "" + self.token_expiry = 0 redis = get_redis() - await redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - await redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.delete(f"fetcher:access_token:{self.client_id}") + await redis.delete(f"fetcher:refresh_token:{self.client_id}") + logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}") + raise + + async def _trigger_reauthorization(self) -> None: + """ + 触发重新授权流程 + 清除所有 token 并重置重试计数器 + """ + logger.error( + f"Authentication failed after {self._auth_retry_count} attempts. " + f"Triggering reauthorization for client {self.client_id}" + ) + + # 清除内存中的 token + self.access_token = "" + self.refresh_token = "" + self.token_expiry = 0 + self._auth_retry_count = 0 # 重置重试计数器 + + # 清除 Redis 中的 token + redis = get_redis() + await redis.delete(f"fetcher:access_token:{self.client_id}") + await redis.delete(f"fetcher:refresh_token:{self.client_id}") + + logger.warning( + f"All tokens cleared for client {self.client_id}. " + f"Please re-authorize using: {self.authorize_url}" + ) + + def reset_auth_retry_count(self) -> None: + """ + 重置授权重试计数器 + 可以在手动重新授权后调用 + """ + self._auth_retry_count = 0 + logger.info(f"Auth retry count reset for client {self.client_id}") + + def get_auth_status(self) -> dict: + """ + 获取当前授权状态信息 + """ + return { + "client_id": self.client_id, + "has_access_token": bool(self.access_token), + "has_refresh_token": bool(self.refresh_token), + "token_expired": self.is_token_expired(), + "auth_retry_count": self._auth_retry_count, + "max_retries": self.max_retries, + "authorize_url": self.authorize_url, + "needs_reauth": self._auth_retry_count >= self.max_retries, + } diff --git a/app/fetcher/beatmap_raw.py b/app/fetcher/beatmap_raw.py index 985fc48..25e9152 100644 --- a/app/fetcher/beatmap_raw.py +++ b/app/fetcher/beatmap_raw.py @@ -38,10 +38,10 @@ class BeatmapRawFetcher(BaseFetcher): self, redis: redis.Redis, beatmap_id: int ) -> str: from app.config import settings - + cache_key = f"beatmap:{beatmap_id}:raw" cache_expire = settings.beatmap_cache_expire_hours * 60 * 60 - + # 检查缓存 if await redis.exists(cache_key): content = await redis.get(cache_key) @@ -49,7 +49,7 @@ class BeatmapRawFetcher(BaseFetcher): # 延长缓存时间 await redis.expire(cache_key, cache_expire) return content # pyright: ignore[reportReturnType] - + # 获取并缓存 raw = await self.get_beatmap_raw(beatmap_id) await redis.set(cache_key, raw, ex=cache_expire)