refactor(assets_proxy): use decorators to simplify code

This commit is contained in:
MingxuanGame
2025-10-03 17:12:28 +00:00
parent d490239f46
commit 046f894407
53 changed files with 151 additions and 313 deletions

View File

@@ -0,0 +1,108 @@
"""资源代理辅助方法与路由装饰器。"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from functools import wraps
import re
from typing import Any
from app.config import settings
from fastapi import Response
from pydantic import BaseModel
Handler = Callable[..., Awaitable[Any]]
def _replace_asset_urls_in_string(value: str) -> str:
result = value
custom_domain = settings.custom_asset_domain
asset_prefix = settings.asset_proxy_prefix
avatar_prefix = settings.avatar_proxy_prefix
beatmap_prefix = settings.beatmap_proxy_prefix
audio_proxy_base_url = f"{settings.server_url}api/private/audio/beatmapset"
result = re.sub(
r"^https://assets\.ppy\.sh/",
f"https://{asset_prefix}.{custom_domain}/",
result,
)
result = re.sub(
r"^https://b\.ppy\.sh/preview/(\d+)\\.mp3",
rf"{audio_proxy_base_url}/\1",
result,
)
result = re.sub(
r"^//b\.ppy\.sh/preview/(\d+)\\.mp3",
rf"{audio_proxy_base_url}/\1",
result,
)
result = re.sub(
r"^https://a\.ppy\.sh/",
f"https://{avatar_prefix}.{custom_domain}/",
result,
)
result = re.sub(
r"https://b\.ppy\.sh/",
f"https://{beatmap_prefix}.{custom_domain}/",
result,
)
return result
def _replace_asset_urls_in_data(data: Any) -> Any:
if isinstance(data, str):
return _replace_asset_urls_in_string(data)
if isinstance(data, list):
return [_replace_asset_urls_in_data(item) for item in data]
if isinstance(data, tuple):
return tuple(_replace_asset_urls_in_data(item) for item in data)
if isinstance(data, dict):
return {key: _replace_asset_urls_in_data(value) for key, value in data.items()}
return data
async def replace_asset_urls(data: Any) -> Any:
"""替换数据中的 osu! 资源 URL。"""
if not settings.enable_asset_proxy:
return data
if hasattr(data, "model_dump"):
raw = data.model_dump()
processed = _replace_asset_urls_in_data(raw)
try:
return data.__class__(**processed)
except Exception:
return processed
if isinstance(data, (dict, list, tuple, str)):
return _replace_asset_urls_in_data(data)
return data
def asset_proxy_response(func: Handler) -> Handler:
"""装饰器:在返回响应前替换资源 URL。"""
@wraps(func)
async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
if not settings.enable_asset_proxy:
return result
if isinstance(result, Response):
return result
if isinstance(result, BaseModel):
result = result.model_dump()
return _replace_asset_urls_in_data(result)
return wrapper # type: ignore[return-value]

View File

@@ -204,3 +204,6 @@ class SearchQueryModel(BaseModel):
default=None,
description="游标字符串,用于分页",
)
SearchQueryModel.model_rebuild()

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
from .file import file_router as file_router

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import timedelta
import re
from typing import Annotated, Literal

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.dependencies.fetcher import Fetcher
from fastapi import APIRouter

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.dependencies.storage import StorageService as StorageServiceDep
from app.storage import LocalStorageService

View File

@@ -1,7 +1,5 @@
"""LIO (Legacy IO) router for osu-server-spectator compatibility."""
from __future__ import annotations
import base64
import json
from typing import Any

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.config import settings
from app.database.notification import Notification, UserNotification
from app.database.user import User

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from math import ceil

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated, Any, Literal, Self
from app.database.chat import (

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database import ChatMessageResp

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
from typing import Annotated, overload

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.config import settings
from . import admin, audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database.auth import OAuthToken

View File

@@ -3,8 +3,6 @@
提供从osu!官方获取beatmapset音频预览的代理服务
"""
from __future__ import annotations
from typing import Annotated
from app.dependencies.database import get_redis, get_redis_binary

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import hashlib
from typing import Annotated

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import hashlib
from typing import Annotated

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import secrets
from typing import Annotated

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database import Relationship

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.dependencies.rate_limit import LIMITERS
from fastapi import APIRouter

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.database.score import Score
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import hashlib
from typing import Annotated

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.auth import (

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.auth import validate_username

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import urllib.parse
from app.config import settings

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from . import beatmap, public_user, replay, score, user # noqa: F401
from .public_router import public_router as api_v1_public_router
from .router import router as api_v1_router

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Literal

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated, Literal
from app.database.statistics import UserStatistics

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import base64
from datetime import date
from typing import Annotated, Literal

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Annotated, Literal

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Literal

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from . import ( # noqa: F401
beatmap,
beatmapset,

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import asyncio
import hashlib
import json

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import re
from typing import Annotated, Literal
from urllib.parse import parse_qs
@@ -12,8 +10,8 @@ from app.dependencies.database import Database, Redis, with_db
from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import ClientUser, get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.beatmap import SearchQueryModel
from app.service.asset_proxy_helper import process_response_assets
from app.service.beatmapset_cache_service import generate_hash
from .router import router
@@ -45,8 +43,9 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
tags=["谱面集"],
response_model=SearchBeatmapsetsResp,
)
@asset_proxy_response
async def search_beatmapset(
query: Annotated[SearchQueryModel, Query(...)],
query: Annotated[SearchQueryModel, Query()],
request: Request,
background_tasks: BackgroundTasks,
current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
@@ -102,9 +101,7 @@ async def search_beatmapset(
cached_result = await cache_service.get_search_from_cache(query_hash, cursor_hash)
if cached_result:
sets = SearchBeatmapsetsResp(**cached_result)
# 处理资源代理
processed_sets = await process_response_assets(sets)
return processed_sets
return sets
try:
sets = await fetcher.search_beatmapset(query, cursor, redis)
@@ -112,10 +109,7 @@ async def search_beatmapset(
# 缓存搜索结果
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
# 处理资源代理
processed_sets = await process_response_assets(sets)
return processed_sets
return sets
except HTTPError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@@ -127,6 +121,7 @@ async def search_beatmapset(
response_model=BeatmapsetResp,
description=("通过谱面 ID 查询所属谱面集。"),
)
@asset_proxy_response
async def lookup_beatmapset(
db: Database,
request: Request,
@@ -138,9 +133,7 @@ async def lookup_beatmapset(
# 先尝试从缓存获取
cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
if cached_resp:
# 处理资源代理
processed_resp = await process_response_assets(cached_resp)
return processed_resp
return cached_resp
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
@@ -148,10 +141,7 @@ async def lookup_beatmapset(
# 缓存结果
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
# 处理资源代理
processed_resp = await process_response_assets(resp)
return processed_resp
return resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -163,6 +153,7 @@ async def lookup_beatmapset(
response_model=BeatmapsetResp,
description="获取单个谱面集详情。",
)
@asset_proxy_response
async def get_beatmapset(
db: Database,
request: Request,
@@ -174,9 +165,7 @@ async def get_beatmapset(
# 先尝试从缓存获取
cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
if cached_resp:
# 处理资源代理
processed_resp = await process_response_assets(cached_resp)
return processed_resp
return cached_resp
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
@@ -184,10 +173,7 @@ async def get_beatmapset(
# 缓存结果
await cache_service.cache_beatmapset(resp)
# 处理资源代理
processed_resp = await process_response_assets(resp)
return processed_resp
return resp
except HTTPError as exc:
raise HTTPException(status_code=404, detail="Beatmapset not found") from exc

View File

@@ -3,8 +3,6 @@
提供缓存统计、清理和预热功能
"""
from __future__ import annotations
from app.dependencies.database import Redis
from app.service.user_cache_service import get_user_cache_service

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database import MeResp, User

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import UTC, datetime
from app.config import settings

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated, Literal
from app.config import settings

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database import Relationship, RelationshipResp, RelationshipType, User

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import UTC
from typing import Annotated, Literal

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from app.dependencies.rate_limit import LIMITERS
from fastapi import APIRouter

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import UTC, date
import time
from typing import Annotated

View File

@@ -2,8 +2,6 @@
会话验证路由 - 实现类似 osu! 的邮件验证流程 (API v2)
"""
from __future__ import annotations
from typing import Annotated, Literal
from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
from datetime import timedelta
from typing import Annotated, Literal
@@ -19,10 +17,10 @@ from app.database.user import SEARCH_INCLUDED
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.log import log
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from app.service.asset_proxy_helper import process_response_assets
from app.service.user_cache_service import get_user_cache_service
from app.utils import utcnow
@@ -47,6 +45,7 @@ class BatchUserResponse(BaseModel):
)
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
@asset_proxy_response
async def get_users(
session: Database,
request: Request,
@@ -89,28 +88,25 @@ async def get_users(
# 异步缓存,不阻塞响应
background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理
response = BatchUserResponse(users=cached_users)
processed_response = await process_response_assets(response)
return processed_response
return response
else:
searched_users = (await session.exec(select(User).limit(50))).all()
users = []
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
users.append(user_resp)
# 异步缓存
background_task.add_task(cache_service.cache_user, user_resp)
if searched_user.id == BANCHOBOT_ID:
continue
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
users.append(user_resp)
# 异步缓存
background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理
response = BatchUserResponse(users=users)
processed_response = await process_response_assets(response)
return processed_response
return response
@router.get(
@@ -176,6 +172,7 @@ async def get_user_kudosu(
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
tags=["用户"],
)
@asset_proxy_response
async def get_user_info_ruleset(
session: Database,
background_task: BackgroundTasks,
@@ -224,6 +221,7 @@ async def get_user_info_ruleset(
description="通过用户 ID 或用户名获取单个用户的详细信息。",
tags=["用户"],
)
@asset_proxy_response
async def get_user_info(
background_task: BackgroundTasks,
session: Database,
@@ -239,9 +237,7 @@ async def get_user_info(
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
# 处理资源代理
processed_user = await process_response_assets(cached_user)
return processed_user
return cached_user
searched_user = (
await session.exec(
@@ -262,9 +258,7 @@ async def get_user_info(
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp)
# 处理资源代理
processed_user = await process_response_assets(user_resp)
return processed_user
return user_resp
@router.get(
@@ -274,6 +268,7 @@ async def get_user_info(
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
tags=["用户"],
)
@asset_proxy_response
async def get_user_beatmapsets(
session: Database,
background_task: BackgroundTasks,
@@ -354,6 +349,7 @@ async def get_user_beatmapsets(
),
tags=["用户"],
)
@asset_proxy_response
async def get_user_scores(
request: Request,
session: Database,
@@ -381,8 +377,7 @@ async def get_user_scores(
user_id, type, include_fails, mode, limit, offset, is_legacy_api
)
if cached_scores is not None:
processed_scores = await process_response_assets(cached_scores)
return processed_scores
return cached_scores
db_user = await session.get(User, user_id)
if not db_user or db_user.id == BANCHOBOT_ID:
@@ -437,6 +432,4 @@ async def get_user_scores(
is_legacy_api,
)
# 处理资源代理
processed_scores = await process_response_assets(score_responses)
return processed_scores
return score_responses

View File

@@ -1,79 +0,0 @@
"""
资源代理辅助函数和中间件
"""
from __future__ import annotations
from typing import Any
from app.config import settings
from app.service.asset_proxy_service import get_asset_proxy_service
from fastapi import Request
async def process_response_assets(data: Any) -> Any:
"""
根据配置处理响应数据中的资源URL
Args:
data: API响应数据
request: FastAPI请求对象
Returns:
处理后的数据
"""
if not settings.enable_asset_proxy:
return data
asset_service = get_asset_proxy_service()
# 仅URL替换模式
return await asset_service.replace_asset_urls(data)
def should_process_asset_proxy(path: str) -> bool:
"""
判断路径是否需要处理资源代理
"""
# 只对特定的API端点处理资源代理
asset_proxy_endpoints = [
"/api/v1/users/",
"/api/v2/users/",
"/api/v1/me/",
"/api/v2/me/",
"/api/v2/beatmapsets/search",
"/api/v2/beatmapsets/lookup",
"/api/v2/beatmaps/",
"/api/v1/beatmaps/",
"/api/v2/beatmapsets/",
# 可以根据需要添加更多端点
]
return any(path.startswith(endpoint) for endpoint in asset_proxy_endpoints)
# 响应处理装饰器
def asset_proxy_response(func):
"""
装饰器自动处理响应中的资源URL
"""
async def wrapper(*args, **kwargs):
# 获取request对象
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
# 执行原函数
result = await func(*args, **kwargs)
# 如果有request对象且启用了资源代理则处理响应
if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
result = await process_response_assets(result)
return result
return wrapper

View File

@@ -1,83 +0,0 @@
"""
资源文件代理服务
提供URL替换方案将osu!官方资源URL替换为自定义域名
"""
from __future__ import annotations
import re
from typing import Any
from app.config import settings
class AssetProxyService:
"""资源代理服务 - 仅URL替换模式"""
def __init__(self):
# 从配置获取自定义assets域名和前缀
self.custom_asset_domain = settings.custom_asset_domain
self.asset_proxy_prefix = settings.asset_proxy_prefix
self.avatar_proxy_prefix = settings.avatar_proxy_prefix
self.beatmap_proxy_prefix = settings.beatmap_proxy_prefix
# 音频代理接口URL
self.audio_proxy_base_url = f"{settings.server_url}api/private/audio/beatmapset"
async def replace_asset_urls(self, data: Any) -> Any:
"""
递归替换数据中的osu!资源URL为自定义域名
"""
# 处理Pydantic模型
if hasattr(data, "model_dump"):
# 转换为字典,处理后再转换回模型
data_dict = data.model_dump()
processed_dict = await self.replace_asset_urls(data_dict)
# 尝试从字典重新创建模型
try:
return data.__class__(**processed_dict)
except Exception:
# 如果重新创建失败,返回字典
return processed_dict
elif isinstance(data, dict):
result = {}
for key, value in data.items():
result[key] = await self.replace_asset_urls(value)
return result
elif isinstance(data, list):
return [await self.replace_asset_urls(item) for item in data]
elif isinstance(data, str):
# 替换各种osu!资源域名
result = data
# 替换 assets.ppy.sh (用户头像、封面、奖章等)
result = re.sub(
r"https://assets\.ppy\.sh/", f"https://{self.asset_proxy_prefix}.{self.custom_asset_domain}/", result
)
# 替换 b.ppy.sh 预览音频为我们的音频代理接口
# 匹配 https://b.ppy.sh/preview/{beatmapset_id}.mp3 格式
result = re.sub(r"https://b\.ppy\.sh/preview/(\d+)\.mp3", rf"{self.audio_proxy_base_url}/\1", result)
# 匹配 //b.ppy.sh/preview/{beatmapset_id}.mp3 格式
result = re.sub(r"//b\.ppy\.sh/preview/(\d+)\.mp3", rf"{self.audio_proxy_base_url}/\1", result)
# 替换 a.ppy.sh 头像
result = re.sub(
r"https://a\.ppy\.sh/", f"https://{self.avatar_proxy_prefix}.{self.custom_asset_domain}/", result
)
return result
else:
return data
# 全局实例
_asset_proxy_service: AssetProxyService | None = None
def get_asset_proxy_service() -> AssetProxyService:
"""获取资源代理服务实例"""
global _asset_proxy_service
if _asset_proxy_service is None:
_asset_proxy_service = AssetProxyService()
return _asset_proxy_service

View File

@@ -12,9 +12,9 @@ from typing import TYPE_CHECKING, Literal
from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.service.asset_proxy_service import get_asset_proxy_service
from app.utils import utcnow
from redis.asyncio import Redis
@@ -357,16 +357,15 @@ class RankingCacheService:
for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
user_dict = user_stats_resp.model_dump()
# 应用资源代理处理
if settings.enable_asset_proxy:
try:
asset_proxy_service = get_asset_proxy_service()
user_stats_resp = await asset_proxy_service.replace_asset_urls(user_stats_resp)
user_dict = await replace_asset_urls(user_dict)
except Exception as e:
logger.warning(f"Asset proxy processing failed for ranking cache: {e}")
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
user_dict = json.loads(user_stats_resp.model_dump_json())
ranking_data.append(user_dict)
# 缓存这一页的数据

View File

@@ -15,9 +15,9 @@ from app.database import User, UserResp
from app.database.score import LegacyScoreResp, ScoreResp
from app.database.user import SEARCH_INCLUDED
from app.dependencies.database import with_db
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.service.asset_proxy_service import get_asset_proxy_service
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -318,8 +318,7 @@ class UserCacheService:
# 应用资源代理处理
if settings.enable_asset_proxy:
try:
asset_proxy_service = get_asset_proxy_service()
user_resp = await asset_proxy_service.replace_asset_urls(user_resp)
user_resp = await replace_asset_urls(user_resp)
except Exception as e:
logger.warning(f"Asset proxy processing failed for user cache {user.id}: {e}")

View File

@@ -86,7 +86,7 @@ ignore = [
"migrations/**/*.py" = ["INP001"]
".github/**/*.py" = ["INP001"]
"app/achievements/*.py" = ["INP001", "ARG"]
"app/router/**/*.py" = ["ARG001"]
"app/router/**/*.py" = ["ARG001", "I002"]
[tool.ruff.lint.isort]
force-sort-within-sections = true