fix token
This commit is contained in:
@@ -89,7 +89,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
|||||||
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
||||||
import asyncio
|
import asyncio
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def _calculate_pp_sync():
|
def _calculate_pp_sync():
|
||||||
map = rosu.Beatmap(content=beatmap)
|
map = rosu.Beatmap(content=beatmap)
|
||||||
mods = deepcopy(score.mods.copy())
|
mods = deepcopy(score.mods.copy())
|
||||||
@@ -111,7 +111,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
|||||||
misses=score.nmiss,
|
misses=score.nmiss,
|
||||||
)
|
)
|
||||||
return perf.calculate(map)
|
return perf.calculate(map)
|
||||||
|
|
||||||
# 在线程池中执行计算
|
# 在线程池中执行计算
|
||||||
attrs = await loop.run_in_executor(None, _calculate_pp_sync)
|
attrs = await loop.run_in_executor(None, _calculate_pp_sync)
|
||||||
pp = attrs.pp
|
pp = attrs.pp
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from app.dependencies.database import get_redis
|
from app.dependencies.database import get_redis
|
||||||
|
from app.log import logger
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
class TokenAuthError(Exception):
|
||||||
|
"""Token 授权失败异常"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseFetcher:
|
class BaseFetcher:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -14,6 +21,7 @@ class BaseFetcher:
|
|||||||
client_secret: str,
|
client_secret: str,
|
||||||
scope: list[str] = ["public"],
|
scope: list[str] = ["public"],
|
||||||
callback_url: str = "",
|
callback_url: str = "",
|
||||||
|
max_retries: int = 3,
|
||||||
):
|
):
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
@@ -22,6 +30,8 @@ class BaseFetcher:
|
|||||||
self.token_expiry: int = 0
|
self.token_expiry: int = 0
|
||||||
self.callback_url: str = callback_url
|
self.callback_url: str = callback_url
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self._auth_retry_count = 0 # 授权重试计数器
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def authorize_url(self) -> str:
|
def authorize_url(self) -> str:
|
||||||
@@ -39,20 +49,91 @@ class BaseFetcher:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
|
async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
|
||||||
if self.is_token_expired():
|
"""
|
||||||
await self.refresh_access_token()
|
发送 API 请求,具有智能重试和自动重新授权机制
|
||||||
header = kwargs.pop("headers", {})
|
"""
|
||||||
header = self.header
|
return await self._request_with_retry(url, method, **kwargs)
|
||||||
|
|
||||||
async with AsyncClient() as client:
|
async def _request_with_retry(
|
||||||
response = await client.request(
|
self, url: str, method: str = "GET", max_retries: Optional[int] = None, **kwargs
|
||||||
method,
|
) -> dict:
|
||||||
url,
|
"""
|
||||||
headers=header,
|
带重试机制的请求方法
|
||||||
**kwargs,
|
"""
|
||||||
)
|
if max_retries is None:
|
||||||
response.raise_for_status()
|
max_retries = self.max_retries
|
||||||
return response.json()
|
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
# 检查 token 是否过期
|
||||||
|
if self.is_token_expired():
|
||||||
|
await self.refresh_access_token()
|
||||||
|
|
||||||
|
header = kwargs.pop("headers", {})
|
||||||
|
header.update(self.header)
|
||||||
|
|
||||||
|
async with AsyncClient() as client:
|
||||||
|
response = await client.request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
headers=header,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理 401 错误
|
||||||
|
if response.status_code == 401:
|
||||||
|
self._auth_retry_count += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Received 401 error (attempt {attempt + 1}/{max_retries + 1}) "
|
||||||
|
f"for {url}, auth retry count: {self._auth_retry_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果达到最大重试次数,触发重新授权
|
||||||
|
if self._auth_retry_count >= self.max_retries:
|
||||||
|
await self._trigger_reauthorization()
|
||||||
|
raise TokenAuthError(
|
||||||
|
f"Authentication failed after {self._auth_retry_count} attempts. "
|
||||||
|
f"Please re-authorize using: {self.authorize_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果还有重试机会,刷新 token 后继续
|
||||||
|
if attempt < max_retries:
|
||||||
|
await self.refresh_access_token()
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 最后一次重试也失败了
|
||||||
|
await self._trigger_reauthorization()
|
||||||
|
raise TokenAuthError(
|
||||||
|
f"Max retries ({max_retries}) exceeded for authentication. "
|
||||||
|
f"Please re-authorize using: {self.authorize_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 请求成功,重置重试计数器
|
||||||
|
self._auth_retry_count = 0
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except TokenAuthError:
|
||||||
|
# 重新抛出授权错误
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
if attempt < max_retries:
|
||||||
|
logger.warning(
|
||||||
|
f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying..."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果所有重试都失败了
|
||||||
|
if last_error:
|
||||||
|
raise last_error
|
||||||
|
else:
|
||||||
|
raise Exception(f"Request to {url} failed after {max_retries + 1} attempts")
|
||||||
|
|
||||||
def is_token_expired(self) -> bool:
|
def is_token_expired(self) -> bool:
|
||||||
return self.token_expiry <= int(time.time())
|
return self.token_expiry <= int(time.time())
|
||||||
@@ -86,28 +167,91 @@ class BaseFetcher:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def refresh_access_token(self) -> None:
|
async def refresh_access_token(self) -> None:
|
||||||
async with AsyncClient() as client:
|
try:
|
||||||
response = await client.post(
|
logger.info(f"Refreshing access token for client {self.client_id}")
|
||||||
"https://osu.ppy.sh/oauth/token",
|
async with AsyncClient() as client:
|
||||||
data={
|
response = await client.post(
|
||||||
"client_id": self.client_id,
|
"https://osu.ppy.sh/oauth/token",
|
||||||
"client_secret": self.client_secret,
|
data={
|
||||||
"grant_type": "refresh_token",
|
"client_id": self.client_id,
|
||||||
"refresh_token": self.refresh_token,
|
"client_secret": self.client_secret,
|
||||||
},
|
"grant_type": "refresh_token",
|
||||||
)
|
"refresh_token": self.refresh_token,
|
||||||
response.raise_for_status()
|
},
|
||||||
token_data = response.json()
|
)
|
||||||
self.access_token = token_data["access_token"]
|
response.raise_for_status()
|
||||||
self.refresh_token = token_data.get("refresh_token", "")
|
token_data = response.json()
|
||||||
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
self.access_token = token_data["access_token"]
|
||||||
|
self.refresh_token = token_data.get("refresh_token", "")
|
||||||
|
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
||||||
|
redis = get_redis()
|
||||||
|
await redis.set(
|
||||||
|
f"fetcher:access_token:{self.client_id}",
|
||||||
|
self.access_token,
|
||||||
|
ex=token_data["expires_in"],
|
||||||
|
)
|
||||||
|
await redis.set(
|
||||||
|
f"fetcher:refresh_token:{self.client_id}",
|
||||||
|
self.refresh_token,
|
||||||
|
)
|
||||||
|
logger.info(f"Successfully refreshed access token for client {self.client_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
|
||||||
|
# 清除无效的 token,要求重新授权
|
||||||
|
self.access_token = ""
|
||||||
|
self.refresh_token = ""
|
||||||
|
self.token_expiry = 0
|
||||||
redis = get_redis()
|
redis = get_redis()
|
||||||
await redis.set(
|
await redis.delete(f"fetcher:access_token:{self.client_id}")
|
||||||
f"fetcher:access_token:{self.client_id}",
|
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
|
||||||
self.access_token,
|
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
|
||||||
ex=token_data["expires_in"],
|
raise
|
||||||
)
|
|
||||||
await redis.set(
|
async def _trigger_reauthorization(self) -> None:
|
||||||
f"fetcher:refresh_token:{self.client_id}",
|
"""
|
||||||
self.refresh_token,
|
触发重新授权流程
|
||||||
)
|
清除所有 token 并重置重试计数器
|
||||||
|
"""
|
||||||
|
logger.error(
|
||||||
|
f"Authentication failed after {self._auth_retry_count} attempts. "
|
||||||
|
f"Triggering reauthorization for client {self.client_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 清除内存中的 token
|
||||||
|
self.access_token = ""
|
||||||
|
self.refresh_token = ""
|
||||||
|
self.token_expiry = 0
|
||||||
|
self._auth_retry_count = 0 # 重置重试计数器
|
||||||
|
|
||||||
|
# 清除 Redis 中的 token
|
||||||
|
redis = get_redis()
|
||||||
|
await redis.delete(f"fetcher:access_token:{self.client_id}")
|
||||||
|
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"All tokens cleared for client {self.client_id}. "
|
||||||
|
f"Please re-authorize using: {self.authorize_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset_auth_retry_count(self) -> None:
|
||||||
|
"""
|
||||||
|
重置授权重试计数器
|
||||||
|
可以在手动重新授权后调用
|
||||||
|
"""
|
||||||
|
self._auth_retry_count = 0
|
||||||
|
logger.info(f"Auth retry count reset for client {self.client_id}")
|
||||||
|
|
||||||
|
def get_auth_status(self) -> dict:
|
||||||
|
"""
|
||||||
|
获取当前授权状态信息
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"has_access_token": bool(self.access_token),
|
||||||
|
"has_refresh_token": bool(self.refresh_token),
|
||||||
|
"token_expired": self.is_token_expired(),
|
||||||
|
"auth_retry_count": self._auth_retry_count,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
|
"authorize_url": self.authorize_url,
|
||||||
|
"needs_reauth": self._auth_retry_count >= self.max_retries,
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,10 +38,10 @@ class BeatmapRawFetcher(BaseFetcher):
|
|||||||
self, redis: redis.Redis, beatmap_id: int
|
self, redis: redis.Redis, beatmap_id: int
|
||||||
) -> str:
|
) -> str:
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
cache_key = f"beatmap:{beatmap_id}:raw"
|
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||||
cache_expire = settings.beatmap_cache_expire_hours * 60 * 60
|
cache_expire = settings.beatmap_cache_expire_hours * 60 * 60
|
||||||
|
|
||||||
# 检查缓存
|
# 检查缓存
|
||||||
if await redis.exists(cache_key):
|
if await redis.exists(cache_key):
|
||||||
content = await redis.get(cache_key)
|
content = await redis.get(cache_key)
|
||||||
@@ -49,7 +49,7 @@ class BeatmapRawFetcher(BaseFetcher):
|
|||||||
# 延长缓存时间
|
# 延长缓存时间
|
||||||
await redis.expire(cache_key, cache_expire)
|
await redis.expire(cache_key, cache_expire)
|
||||||
return content # pyright: ignore[reportReturnType]
|
return content # pyright: ignore[reportReturnType]
|
||||||
|
|
||||||
# 获取并缓存
|
# 获取并缓存
|
||||||
raw = await self.get_beatmap_raw(beatmap_id)
|
raw = await self.get_beatmap_raw(beatmap_id)
|
||||||
await redis.set(cache_key, raw, ex=cache_expire)
|
await redis.set(cache_key, raw, ex=cache_expire)
|
||||||
|
|||||||
Reference in New Issue
Block a user