From 346c2557cfec9012a4af7da7fb8b6b30fdfeda1a Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 3 Oct 2025 05:41:31 +0000 Subject: [PATCH] refactor(api): use Annotated-style dependency injection --- app/database/playlists.py | 9 +- app/database/room.py | 7 +- app/dependencies/__init__.py | 3 - app/dependencies/api_version.py | 2 +- app/dependencies/beatmap_download.py | 9 +- app/dependencies/beatmapset_cache.py | 19 ++-- app/dependencies/database.py | 3 + app/dependencies/fetcher.py | 15 +++- app/dependencies/geoip.py | 9 +- app/dependencies/storage.py | 11 ++- app/dependencies/user.py | 10 ++- app/router/auth.py | 73 +++++++--------- app/router/fetcher.py | 7 +- app/router/file.py | 8 +- app/router/lio.py | 17 ++-- app/router/notification/channel.py | 46 +++++----- app/router/notification/message.py | 37 ++++---- app/router/notification/server.py | 12 +-- app/router/private/admin.py | 19 ++-- app/router/private/avatar.py | 15 ++-- app/router/private/beatmapset.py | 15 ++-- app/router/private/cover.py | 16 ++-- app/router/private/oauth.py | 41 +++++---- app/router/private/relationship.py | 12 +-- app/router/private/score.py | 17 ++-- app/router/private/team.py | 67 +++++++------- app/router/private/totp.py | 28 +++--- app/router/private/username.py | 10 ++- app/router/v1/beatmap.py | 34 ++++---- app/router/v1/public_user.py | 8 +- app/router/v1/replay.py | 31 +++---- app/router/v1/score.py | 30 +++---- app/router/v1/user.py | 10 +-- app/router/v2/beatmap.py | 57 ++++++------ app/router/v2/beatmapset.py | 62 +++++++------ app/router/v2/cache.py | 13 ++- app/router/v2/me.py | 15 ++-- app/router/v2/ranking.py | 68 ++++++++------- app/router/v2/relationship.py | 14 +-- app/router/v2/room.py | 87 +++++++++--------- app/router/v2/score.py | 126 +++++++++++++-------------- app/router/v2/session_verify.py | 27 +++--- app/router/v2/tags.py | 20 +++-- app/router/v2/user.py | 58 ++++++------ app/signalr/router.py | 3 +- 45 files changed, 623 insertions(+), 577 deletions(-) diff --git a/app/database/playlists.py b/app/database/playlists.py index e36441e..b589e96 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from app.models.model import UTCBaseModel from app.models.mods import APIMod -from app.models.multiplayer_hub import PlaylistItem from .beatmap import Beatmap, BeatmapResp @@ -22,6 +21,8 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: + from app.models.multiplayer_hub import PlaylistItem + from .room import Room @@ -72,7 +73,7 @@ class Playlist(PlaylistBase, table=True): return result.one() @classmethod - async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist": + async def from_hub(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession) -> "Playlist": next_id = await cls.get_next_id_for_room(room_id, session=session) return cls( id=next_id, @@ -89,7 +90,7 @@ class Playlist(PlaylistBase, table=True): ) @classmethod - async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): + async def update(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession): db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id)) db_playlist = db_playlist.first() if db_playlist is None: @@ -106,7 +107,7 @@ class Playlist(PlaylistBase, table=True): await session.commit() @classmethod - async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): + async def add_to_db(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession): db_playlist = await cls.from_hub(playlist, room_id, session) session.add(db_playlist) await session.commit() diff --git a/app/database/room.py b/app/database/room.py index 647b7ca..2729e37 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -1,9 +1,9 @@ from datetime import datetime +from typing import TYPE_CHECKING from app.database.item_attempts_count import PlaylistAggregateScore from app.database.room_participated_user import RoomParticipatedUser from app.models.model import UTCBaseModel -from app.models.multiplayer_hub import ServerMultiplayerRoom from app.models.room import ( MatchType, QueueMode, @@ -32,6 +32,9 @@ from sqlmodel import ( ) from sqlmodel.ext.asyncio.session import AsyncSession +if TYPE_CHECKING: + from app.models.multiplayer_hub import ServerMultiplayerRoom + class RoomBase(SQLModel, UTCBaseModel): name: str = Field(index=True) @@ -161,7 +164,7 @@ class RoomResp(RoomBase): return resp @classmethod - async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp": + async def from_hub(cls, server_room: "ServerMultiplayerRoom") -> "RoomResp": room = server_room.room resp = cls( id=room.room_id, diff --git a/app/dependencies/__init__.py b/app/dependencies/__init__.py index cdcce5a..8b13789 100644 --- a/app/dependencies/__init__.py +++ b/app/dependencies/__init__.py @@ -1,4 +1 @@ -from __future__ import annotations -from .database import get_db as get_db -from .user import get_current_user as get_current_user diff --git a/app/dependencies/api_version.py b/app/dependencies/api_version.py index 7cfa1c7..af8489a 100644 --- a/app/dependencies/api_version.py +++ b/app/dependencies/api_version.py @@ -5,7 +5,7 @@ from typing import Annotated from fastapi import Depends, Header -def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int: +def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int: if version is None: return 0 if version < 1: diff --git a/app/dependencies/beatmap_download.py b/app/dependencies/beatmap_download.py index ffed3a0..818dc7e 100644 --- a/app/dependencies/beatmap_download.py +++ b/app/dependencies/beatmap_download.py @@ -1,8 +1,15 @@ from __future__ import annotations -from app.service.beatmap_download_service import download_service +from typing import Annotated + +from app.service.beatmap_download_service import BeatmapDownloadService, download_service + +from fastapi import Depends def get_beatmap_download_service(): """获取谱面下载服务实例""" return download_service + + +DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)] diff --git a/app/dependencies/beatmapset_cache.py b/app/dependencies/beatmapset_cache.py index f5ac96b..df177e2 100644 --- a/app/dependencies/beatmapset_cache.py +++ b/app/dependencies/beatmapset_cache.py @@ -1,16 +1,19 @@ -""" -Beatmapset缓存服务依赖注入 -""" - from __future__ import annotations -from app.dependencies.database import get_redis -from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service +from typing import Annotated + +from app.dependencies.database import Redis +from app.service.beatmapset_cache_service import ( + BeatmapsetCacheService as OriginBeatmapsetCacheService, + get_beatmapset_cache_service, +) from fastapi import Depends -from redis.asyncio import Redis -def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService: +def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService: """获取beatmapset缓存服务依赖""" return get_beatmapset_cache_service(redis) + + +BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)] diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 2fc11fb..1e0a29a 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -91,6 +91,9 @@ def get_redis(): return redis_client +Redis = Annotated[redis.Redis, Depends(get_redis)] + + def get_redis_binary(): """获取二进制数据专用的 Redis 客户端 (不自动解码响应)""" return redis_binary_client diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index b4db26c..ccc3f06 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -1,17 +1,21 @@ from __future__ import annotations +from typing import Annotated + from app.config import settings from app.dependencies.database import get_redis -from app.fetcher import Fetcher +from app.fetcher import Fetcher as OriginFetcher from app.log import logger -fetcher: Fetcher | None = None +from fastapi import Depends + +fetcher: OriginFetcher | None = None -async def get_fetcher() -> Fetcher: +async def get_fetcher() -> OriginFetcher: global fetcher if fetcher is None: - fetcher = Fetcher( + fetcher = OriginFetcher( settings.fetcher_client_id, settings.fetcher_client_secret, settings.fetcher_scopes, @@ -27,3 +31,6 @@ async def get_fetcher() -> Fetcher: if not fetcher.access_token or not fetcher.refresh_token: logger.opt(colors=True).info(f"Login to initialize fetcher: {fetcher.authorize_url}") return fetcher + + +Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)] diff --git a/app/dependencies/geoip.py b/app/dependencies/geoip.py index 9aafafd..089b90c 100644 --- a/app/dependencies/geoip.py +++ b/app/dependencies/geoip.py @@ -6,10 +6,13 @@ from __future__ import annotations from functools import lru_cache import ipaddress +from typing import Annotated from app.config import settings from app.helpers.geoip_helper import GeoIPHelper +from fastapi import Depends, Request + @lru_cache def get_geoip_helper() -> GeoIPHelper: @@ -26,7 +29,7 @@ def get_geoip_helper() -> GeoIPHelper: ) -def get_client_ip(request) -> str: +def get_client_ip(request: Request) -> str: """ 获取客户端真实 IP 地址 支持 IPv4 和 IPv6,考虑代理、负载均衡器等情况 @@ -66,6 +69,10 @@ def get_client_ip(request) -> str: return client_ip if is_valid_ip(client_ip) else "127.0.0.1" +IPAddress = Annotated[str, Depends(get_client_ip)] +GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)] + + def is_valid_ip(ip_str: str) -> bool: """ 验证 IP 地址是否有效(支持 IPv4 和 IPv6) diff --git a/app/dependencies/storage.py b/app/dependencies/storage.py index 22906e0..413e5b0 100644 --- a/app/dependencies/storage.py +++ b/app/dependencies/storage.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import cast +from typing import Annotated, cast from app.config import ( AWSS3StorageSettings, @@ -9,11 +9,13 @@ from app.config import ( StorageServiceType, settings, ) -from app.storage import StorageService +from app.storage import StorageService as OriginStorageService from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService from app.storage.local import LocalStorageService -storage: StorageService | None = None +from fastapi import Depends + +storage: OriginStorageService | None = None def init_storage_service(): @@ -50,3 +52,6 @@ def get_storage_service(): if storage is None: return init_storage_service() return storage + + +StorageService = Annotated[OriginStorageService, Depends(get_storage_service)] diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 449b550..ff3ff53 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -4,6 +4,7 @@ from typing import Annotated from app.auth import get_token_by_access_token from app.config import settings +from app.const import SUPPORT_TOTP_VERIFICATION_VER from app.database import User from app.database.auth import OAuthToken, V1APIKeys from app.models.oauth import OAuth2ClientCredentialsBearer @@ -11,7 +12,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer from .api_version import APIVersion from .database import Database, get_redis -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Security from fastapi.security import ( APIKeyQuery, HTTPBearer, @@ -112,13 +113,13 @@ async def get_client_user( if await LoginSessionService.check_is_need_verification(db, user.id, token.id): # 获取当前验证方式 verify_method = None - if api_version >= 20250913: + if api_version >= SUPPORT_TOTP_VERIFICATION_VER: verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis) if verify_method is None: # 智能选择验证方式(有TOTP优先TOTP) totp_key = await user.awaitable_attrs.totp_key - if totp_key is not None and api_version >= 20240101: + if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER: verify_method = "totp" else: verify_method = "mail" @@ -169,3 +170,6 @@ async def get_current_user( user_and_token: UserAndToken = Depends(get_current_user_and_token), ) -> User: return user_and_token[0] + + +ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])] diff --git a/app/router/auth.py b/app/router/auth.py index 544ce50..f2e7336 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import timedelta import re -from typing import Literal +from typing import Annotated, Literal from app.auth import ( authenticate_user, @@ -19,10 +19,9 @@ from app.const import BANCHOBOT_ID from app.database import DailyChallengeStats, OAuthClient, User from app.database.auth import TotpKeys from app.database.statistics import UserStatistics -from app.dependencies.database import Database, get_redis -from app.dependencies.geoip import get_client_ip, get_geoip_helper +from app.dependencies.database import Database, Redis +from app.dependencies.geoip import GeoIPService, IPAddress from app.dependencies.user_agent import UserAgentInfo -from app.helpers.geoip_helper import GeoIPHelper from app.log import logger from app.models.extended_auth import ExtendedTokenResponse from app.models.oauth import ( @@ -40,9 +39,8 @@ from app.service.verification_service import ( ) from app.utils import utcnow -from fastapi import APIRouter, Depends, Form, Header, Request +from fastapi import APIRouter, Form, Header, Request from fastapi.responses import JSONResponse -from redis.asyncio import Redis from sqlalchemy import text from sqlmodel import exists, select @@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"]) ) async def register_user( db: Database, - request: Request, - user_username: str = Form(..., alias="user[username]", description="用户名"), - user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"), - user_password: str = Form(..., alias="user[password]", description="密码"), - geoip: GeoIPHelper = Depends(get_geoip_helper), + user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")], + user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")], + user_password: Annotated[str, Form(..., alias="user[password]", description="密码")], + geoip: GeoIPService, + client_ip: IPAddress, ): username_errors = validate_username(user_username) email_errors = validate_email(user_email) @@ -126,7 +124,6 @@ async def register_user( try: # 获取客户端 IP 并查询地理位置 - client_ip = get_client_ip(request) country_code = "CN" # 默认国家代码 try: @@ -201,19 +198,21 @@ async def oauth_token( db: Database, request: Request, user_agent: UserAgentInfo, - grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form( - ..., description="授权类型:密码/刷新令牌/授权码/客户端凭证" - ), - client_id: int = Form(..., description="客户端 ID"), - client_secret: str = Form(..., description="客户端密钥"), - code: str | None = Form(None, description="授权码(仅授权码模式需要)"), - scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"), - username: str | None = Form(None, description="用户名(仅密码模式需要)"), - password: str | None = Form(None, description="密码(仅密码模式需要)"), - refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"), - redis: Redis = Depends(get_redis), - geoip: GeoIPHelper = Depends(get_geoip_helper), - web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"), + ip_address: IPAddress, + grant_type: Annotated[ + Literal["authorization_code", "refresh_token", "password", "client_credentials"], + Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"), + ], + client_id: Annotated[int, Form(..., description="客户端 ID")], + client_secret: Annotated[str, Form(..., description="客户端密钥")], + redis: Redis, + geoip: GeoIPService, + code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None, + scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*')")] = "*", + username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None, + password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None, + refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None, + web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None, ): scopes = scope.split(" ") @@ -311,8 +310,6 @@ async def oauth_token( ) token_id = token.id - ip_address = get_client_ip(request) - # 获取国家代码 geo_info = geoip.lookup(ip_address) country_code = geo_info.get("country_iso", "XX") @@ -571,16 +568,14 @@ async def oauth_token( ) async def request_password_reset( request: Request, - email: str = Form(..., description="邮箱地址"), - redis: Redis = Depends(get_redis), + email: Annotated[str, Form(..., description="邮箱地址")], + redis: Redis, + ip_address: IPAddress, ): """ 请求密码重置 """ - from app.dependencies.geoip import get_client_ip - # 获取客户端信息 - ip_address = get_client_ip(request) user_agent = request.headers.get("User-Agent", "") # 请求密码重置 @@ -599,20 +594,16 @@ async def request_password_reset( @router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码") async def reset_password( - request: Request, - email: str = Form(..., description="邮箱地址"), - reset_code: str = Form(..., description="重置验证码"), - new_password: str = Form(..., description="新密码"), - redis: Redis = Depends(get_redis), + email: Annotated[str, Form(..., description="邮箱地址")], + reset_code: Annotated[str, Form(..., description="重置验证码")], + new_password: Annotated[str, Form(..., description="新密码")], + redis: Redis, + ip_address: IPAddress, ): """ 重置密码 """ - from app.dependencies.geoip import get_client_ip - # 获取客户端信息 - ip_address = get_client_ip(request) - # 重置密码 success, message = await password_reset_service.reset_password( email=email.lower().strip(), diff --git a/app/router/fetcher.py b/app/router/fetcher.py index f936ed6..887eabf 100644 --- a/app/router/fetcher.py +++ b/app/router/fetcher.py @@ -1,14 +1,13 @@ from __future__ import annotations -from app.dependencies.fetcher import get_fetcher -from app.fetcher import Fetcher +from app.dependencies.fetcher import Fetcher -from fastapi import APIRouter, Depends +from fastapi import APIRouter fetcher_router = APIRouter(prefix="/fetcher", include_in_schema=False) @fetcher_router.get("/callback") -async def callback(code: str, fetcher: Fetcher = Depends(get_fetcher)): +async def callback(code: str, fetcher: Fetcher): await fetcher.grant_access_token(code) return {"message": "Login successful"} diff --git a/app/router/file.py b/app/router/file.py index bd35a7e..14263f9 100644 --- a/app/router/file.py +++ b/app/router/file.py @@ -1,16 +1,16 @@ from __future__ import annotations -from app.dependencies.storage import get_storage_service -from app.storage import LocalStorageService, StorageService +from app.dependencies.storage import StorageService as StorageServiceDep +from app.storage import LocalStorageService -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, HTTPException from fastapi.responses import FileResponse file_router = APIRouter(prefix="/file", include_in_schema=False) @file_router.get("/{path:path}") -async def get_file(path: str, storage: StorageService = Depends(get_storage_service)): +async def get_file(path: str, storage: StorageServiceDep): if not isinstance(storage, LocalStorageService): raise HTTPException(404, "Not Found") if not await storage.is_exists(path): diff --git a/app/router/lio.py b/app/router/lio.py index 8d3d960..2cabebb 100644 --- a/app/router/lio.py +++ b/app/router/lio.py @@ -11,21 +11,18 @@ from app.database.playlists import Playlist as DBPlaylist from app.database.room import Room from app.database.room_participated_user import RoomParticipatedUser from app.database.user import User -from app.dependencies.database import Database, get_redis -from app.dependencies.fetcher import get_fetcher -from app.dependencies.storage import get_storage_service -from app.fetcher import Fetcher +from app.dependencies.database import Database, Redis +from app.dependencies.fetcher import Fetcher +from app.dependencies.storage import StorageService from app.log import logger from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus -from app.storage.base import StorageService from app.utils import utcnow from .notification.server import server -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, HTTPException, Request, status from pydantic import BaseModel -from redis.asyncio import Redis from sqlalchemy import update from sqlmodel import col, select @@ -637,8 +634,8 @@ async def add_user_to_room( async def ensure_beatmap_present( beatmap_data: BeatmapEnsureRequest, db: Database, - redis: Redis = Depends(get_redis), - fetcher: Fetcher = Depends(get_fetcher), + redis: Redis, + fetcher: Fetcher, ) -> dict[str, Any]: """ 确保谱面在服务器中存在(包括元数据和原始文件缓存)。 @@ -677,7 +674,7 @@ class ReplayDataRequest(BaseModel): @router.post("/scores/replay") async def save_replay( req: ReplayDataRequest, - storage_service: StorageService = Depends(get_storage_service), + storage_service: StorageService, ): replay_data = req.mreplay replay_path = f"replays/{req.score_id}_{req.beatmap_id}_{req.user_id}_lazer_replay.osr" diff --git a/app/router/notification/channel.py b/app/router/notification/channel.py index bc251ca..62861ec 100644 --- a/app/router/notification/channel.py +++ b/app/router/notification/channel.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal, Self +from typing import Annotated, Any, Literal, Self from app.database.chat import ( ChannelType, @@ -11,7 +11,7 @@ from app.database.chat import ( UserSilenceResp, ) from app.database.user import User, UserResp -from app.dependencies.database import Database, get_redis +from app.dependencies.database import Database, Redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router @@ -20,7 +20,6 @@ from .server import server from fastapi import Depends, HTTPException, Path, Query, Security from pydantic import BaseModel, Field, model_validator -from redis.asyncio import Redis from sqlmodel import col, select @@ -38,11 +37,14 @@ class UpdateResponse(BaseModel): ) async def get_update( session: Database, - history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"), - since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), - includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"), - current_user: User = Security(get_current_user, scopes=["chat.read"]), - redis: Redis = Depends(get_redis), + current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], + redis: Redis, + history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None, + since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None, + includes: Annotated[ + list[str], + Query(alias="includes[]", description="要包含的更新类型"), + ] = ["presence", "silences"], ): resp = UpdateResponse() if "presence" in includes: @@ -86,9 +88,9 @@ async def get_update( ) async def join_channel( session: Database, - channel: str = Path(..., description="频道 ID/名称"), - user: str = Path(..., description="用户 ID"), - current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + channel: Annotated[str, Path(..., description="频道 ID/名称")], + user: Annotated[str, Path(..., description="用户 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])], ): # 使用明确的查询避免延迟加载 if channel.isdigit(): @@ -110,9 +112,9 @@ async def join_channel( ) async def leave_channel( session: Database, - channel: str = Path(..., description="频道 ID/名称"), - user: str = Path(..., description="用户 ID"), - current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + channel: Annotated[str, Path(..., description="频道 ID/名称")], + user: Annotated[str, Path(..., description="用户 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])], ): # 使用明确的查询避免延迟加载 if channel.isdigit(): @@ -135,8 +137,8 @@ async def leave_channel( ) async def get_channel_list( session: Database, - current_user: User = Security(get_current_user, scopes=["chat.read"]), - redis: Redis = Depends(get_redis), + current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], + redis: Redis, ): channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all() results = [] @@ -171,9 +173,9 @@ class GetChannelResp(BaseModel): ) async def get_channel( session: Database, - channel: str = Path(..., description="频道 ID/名称"), - current_user: User = Security(get_current_user, scopes=["chat.read"]), - redis: Redis = Depends(get_redis), + channel: Annotated[str, Path(..., description="频道 ID/名称")], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], + redis: Redis, ): # 使用明确的查询避免延迟加载 if channel.isdigit(): @@ -245,9 +247,9 @@ class CreateChannelReq(BaseModel): ) async def create_channel( session: Database, - req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), - current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), - redis: Redis = Depends(get_redis), + req: Annotated[CreateChannelReq, Depends(BodyOrForm(CreateChannelReq))], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])], + redis: Redis, ): if req.type == "PM": target = await session.get(User, req.target_id) diff --git a/app/router/notification/message.py b/app/router/notification/message.py index 6470c36..41ac452 100644 --- a/app/router/notification/message.py +++ b/app/router/notification/message.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Annotated + from app.database import ChatMessageResp from app.database.chat import ( ChannelType, @@ -11,7 +13,7 @@ from app.database.chat import ( UserSilenceResp, ) from app.database.user import User -from app.dependencies.database import Database, get_redis +from app.dependencies.database import Database, Redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.log import logger @@ -24,7 +26,6 @@ from .server import server from fastapi import Depends, HTTPException, Path, Query, Security from pydantic import BaseModel, Field -from redis.asyncio import Redis from sqlmodel import col, select @@ -41,9 +42,9 @@ class KeepAliveResp(BaseModel): ) async def keep_alive( session: Database, - history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"), - since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), - current_user: User = Security(get_current_user, scopes=["chat.read"]), + current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], + history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None, + since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None, ): resp = KeepAliveResp() if history_since: @@ -73,9 +74,9 @@ class MessageReq(BaseModel): ) async def send_message( session: Database, - channel: str = Path(..., description="频道 ID/名称"), - req: MessageReq = Depends(BodyOrForm(MessageReq)), - current_user: User = Security(get_current_user, scopes=["chat.write"]), + channel: Annotated[str, Path(..., description="频道 ID/名称")], + req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])], ): # 使用明确的查询来获取 channel,避免延迟加载 if channel.isdigit(): @@ -156,10 +157,10 @@ async def send_message( async def get_message( session: Database, channel: str, - limit: int = Query(50, ge=1, le=50, description="获取消息的数量"), - since: int = Query(0, ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)"), - until: int | None = Query(None, description="获取自此消息 ID 之前的消息(向后翻历史)"), - current_user: User = Security(get_current_user, scopes=["chat.read"]), + current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], + limit: Annotated[int, Query(ge=1, le=50, description="获取消息的数量")] = 50, + since: Annotated[int, Query(ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)")] = 0, + until: Annotated[int | None, Query(description="获取自此消息 ID 之前的消息(向后翻历史)")] = None, ): # 1) 查频道 if channel.isdigit(): @@ -220,9 +221,9 @@ async def get_message( ) async def mark_as_read( session: Database, - channel: str = Path(..., description="频道 ID/名称"), - message: int = Path(..., description="消息 ID"), - current_user: User = Security(get_current_user, scopes=["chat.read"]), + channel: Annotated[str, Path(..., description="频道 ID/名称")], + message: Annotated[int, Path(..., description="消息 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])], ): # 使用明确的查询获取 channel,避免延迟加载 if channel.isdigit(): @@ -259,9 +260,9 @@ class NewPMResp(BaseModel): ) async def create_new_pm( session: Database, - req: PMReq = Depends(BodyOrForm(PMReq)), - current_user: User = Security(get_current_user, scopes=["chat.write"]), - redis: Redis = Depends(get_redis), + req: Annotated[PMReq, Depends(BodyOrForm(PMReq))], + current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])], + redis: Redis, ): user_id = current_user.id target = await session.get(User, req.target_id) diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 778338c..29fc663 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -1,13 +1,14 @@ from __future__ import annotations import asyncio -from typing import overload +from typing import Annotated, overload from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.notification import UserNotification, insert_notification from app.database.user import User from app.dependencies.database import ( DBFactory, + Redis, get_db_factory, get_redis, with_db, @@ -22,7 +23,6 @@ from app.utils import bg_tasks from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes from fastapi.websockets import WebSocketState -from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -298,10 +298,10 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): @chat_router.websocket("/notification-server") async def chat_websocket( websocket: WebSocket, - token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"), - access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"), - authorization: str | None = Header(None, description="Bearer认证头"), - factory: DBFactory = Depends(get_db_factory), + factory: Annotated[DBFactory, Depends(get_db_factory)], + token: Annotated[str | None, Query(description="认证令牌,支持通过URL参数传递")] = None, + access_token: Annotated[str | None, Query(description="访问令牌,支持通过URL参数传递")] = None, + authorization: Annotated[str | None, Header(description="Bearer认证头")] = None, ): if not server._subscribed: server._subscribed = True diff --git a/app/router/private/admin.py b/app/router/private/admin.py index e29a264..57dcf98 100644 --- a/app/router/private/admin.py +++ b/app/router/private/admin.py @@ -1,15 +1,16 @@ from __future__ import annotations +from typing import Annotated + from app.database.auth import OAuthToken from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp from app.dependencies.database import Database -from app.dependencies.geoip import get_geoip_helper +from app.dependencies.geoip import GeoIPService from app.dependencies.user import UserAndToken, get_client_user_and_token -from app.helpers.geoip_helper import GeoIPHelper from .router import router -from fastapi import Depends, HTTPException, Security +from fastapi import HTTPException, Security from pydantic import BaseModel from sqlmodel import col, select @@ -28,8 +29,8 @@ class SessionsResp(BaseModel): ) async def get_sessions( session: Database, - user_and_token: UserAndToken = Security(get_client_user_and_token), - geoip: GeoIPHelper = Depends(get_geoip_helper), + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], + geoip: GeoIPService, ): current_user, token = user_and_token sessions = ( @@ -57,7 +58,7 @@ async def get_sessions( async def delete_session( session: Database, session_id: int, - user_and_token: UserAndToken = Security(get_client_user_and_token), + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], ): current_user, token = user_and_token if session_id == token.id: @@ -91,8 +92,8 @@ class TrustedDevicesResp(BaseModel): ) async def get_trusted_devices( session: Database, - user_and_token: UserAndToken = Security(get_client_user_and_token), - geoip: GeoIPHelper = Depends(get_geoip_helper), + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], + geoip: GeoIPService, ): current_user, token = user_and_token devices = ( @@ -131,7 +132,7 @@ async def get_trusted_devices( async def delete_trusted_device( session: Database, device_id: int, - user_and_token: UserAndToken = Security(get_client_user_and_token), + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], ): current_user, token = user_and_token device = await session.get(TrustedDevice, device_id) diff --git a/app/router/private/avatar.py b/app/router/private/avatar.py index e37596f..0af8694 100644 --- a/app/router/private/avatar.py +++ b/app/router/private/avatar.py @@ -1,25 +1,24 @@ from __future__ import annotations import hashlib +from typing import Annotated -from app.database.user import User from app.dependencies.database import Database -from app.dependencies.storage import get_storage_service -from app.dependencies.user import get_client_user -from app.storage.base import StorageService +from app.dependencies.storage import StorageService +from app.dependencies.user import ClientUser from app.utils import check_image from .router import router -from fastapi import Depends, File, Security +from fastapi import File @router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"]) async def upload_avatar( session: Database, - content: bytes = File(...), - current_user: User = Security(get_client_user), - storage: StorageService = Depends(get_storage_service), + content: Annotated[bytes, File(...)], + current_user: ClientUser, + storage: StorageService, ): """上传用户头像 diff --git a/app/router/private/beatmapset.py b/app/router/private/beatmapset.py index ba3c4ee..5b80841 100644 --- a/app/router/private/beatmapset.py +++ b/app/router/private/beatmapset.py @@ -1,17 +1,18 @@ from __future__ import annotations +from typing import Annotated + from app.database.beatmap import Beatmap from app.database.beatmapset import Beatmapset from app.database.beatmapset_ratings import BeatmapRating from app.database.score import Score -from app.database.user import User from app.dependencies.database import Database -from app.dependencies.user import get_client_user +from app.dependencies.user import ClientUser from app.service.beatmapset_update_service import get_beatmapset_update_service from .router import router -from fastapi import Body, Depends, HTTPException, Security +from fastapi import Body, Depends, HTTPException from fastapi_limiter.depends import RateLimiter from sqlmodel import col, exists, select @@ -25,7 +26,7 @@ from sqlmodel import col, exists, select async def can_rate_beatmapset( beatmapset_id: int, session: Database, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): """检查用户是否可以评价谱面集 @@ -57,8 +58,8 @@ async def can_rate_beatmapset( async def rate_beatmaps( beatmapset_id: int, session: Database, - rating: int = Body(..., ge=0, le=10), - current_user: User = Security(get_client_user), + rating: Annotated[int, Body(..., ge=0, le=10)], + current_user: ClientUser, ): """为谱面集评分 @@ -96,7 +97,7 @@ async def rate_beatmaps( async def sync_beatmapset( beatmapset_id: int, session: Database, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): """请求同步谱面集 diff --git a/app/router/private/cover.py b/app/router/private/cover.py index 04f8d1b..71992e0 100644 --- a/app/router/private/cover.py +++ b/app/router/private/cover.py @@ -1,25 +1,25 @@ from __future__ import annotations import hashlib +from typing import Annotated -from app.database.user import User, UserProfileCover +from app.database.user import UserProfileCover from app.dependencies.database import Database -from app.dependencies.storage import get_storage_service -from app.dependencies.user import get_client_user -from app.storage.base import StorageService +from app.dependencies.storage import StorageService +from app.dependencies.user import ClientUser from app.utils import check_image from .router import router -from fastapi import Depends, File, Security +from fastapi import File @router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"]) async def upload_cover( session: Database, - content: bytes = File(...), - current_user: User = Security(get_client_user), - storage: StorageService = Depends(get_storage_service), + content: Annotated[bytes, File(...)], + current_user: ClientUser, + storage: StorageService, ): """上传用户头图 diff --git a/app/router/private/oauth.py b/app/router/private/oauth.py index f4f5d78..2af00dc 100644 --- a/app/router/private/oauth.py +++ b/app/router/private/oauth.py @@ -1,16 +1,15 @@ from __future__ import annotations import secrets +from typing import Annotated from app.database.auth import OAuthClient, OAuthToken -from app.database.user import User -from app.dependencies.database import Database, get_redis -from app.dependencies.user import get_client_user +from app.dependencies.database import Database, Redis +from app.dependencies.user import ClientUser from .router import router -from fastapi import Body, Depends, HTTPException, Security -from redis.asyncio import Redis +from fastapi import Body, HTTPException from sqlmodel import select, text @@ -22,10 +21,10 @@ from sqlmodel import select, text ) async def create_oauth_app( session: Database, - name: str = Body(..., max_length=100, description="应用程序名称"), - description: str = Body("", description="应用程序描述"), - redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), - current_user: User = Security(get_client_user), + name: Annotated[str, Body(..., max_length=100, description="应用程序名称")], + redirect_uris: Annotated[list[str], Body(..., description="允许的重定向 URI 列表")], + current_user: ClientUser, + description: Annotated[str, Body(description="应用程序描述")] = "", ): result = await session.execute( text( @@ -64,7 +63,7 @@ async def create_oauth_app( async def get_oauth_app( session: Database, client_id: int, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): oauth_app = await session.get(OAuthClient, client_id) if not oauth_app: @@ -85,7 +84,7 @@ async def get_oauth_app( ) async def get_user_oauth_apps( session: Database, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id)) return [ @@ -109,7 +108,7 @@ async def get_user_oauth_apps( async def delete_oauth_app( session: Database, client_id: int, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): oauth_client = await session.get(OAuthClient, client_id) if not oauth_client: @@ -134,10 +133,10 @@ async def delete_oauth_app( async def update_oauth_app( session: Database, client_id: int, - name: str = Body(..., max_length=100, description="应用程序新名称"), - description: str = Body("", description="应用程序新描述"), - redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"), - current_user: User = Security(get_client_user), + name: Annotated[str, Body(..., max_length=100, description="应用程序新名称")], + redirect_uris: Annotated[list[str], Body(..., description="新的重定向 URI 列表")], + current_user: ClientUser, + description: Annotated[str, Body(description="应用程序新描述")] = "", ): oauth_client = await session.get(OAuthClient, client_id) if not oauth_client: @@ -168,7 +167,7 @@ async def update_oauth_app( async def refresh_secret( session: Database, client_id: int, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): oauth_client = await session.get(OAuthClient, client_id) if not oauth_client: @@ -200,10 +199,10 @@ async def refresh_secret( async def generate_oauth_code( session: Database, client_id: int, - current_user: User = Security(get_client_user), - redirect_uri: str = Body(..., description="授权后重定向的 URI"), - scopes: list[str] = Body(..., description="请求的权限范围列表"), - redis: Redis = Depends(get_redis), + current_user: ClientUser, + redirect_uri: Annotated[str, Body(..., description="授权后重定向的 URI")], + scopes: Annotated[list[str], Body(..., description="请求的权限范围列表")], + redis: Redis, ): client = await session.get(OAuthClient, client_id) if not client: diff --git a/app/router/private/relationship.py b/app/router/private/relationship.py index 4698350..1f882cb 100644 --- a/app/router/private/relationship.py +++ b/app/router/private/relationship.py @@ -1,13 +1,15 @@ from __future__ import annotations -from app.database import Relationship, User +from typing import Annotated + +from app.database import Relationship from app.database.relationship import RelationshipType from app.dependencies.database import Database -from app.dependencies.user import get_client_user +from app.dependencies.user import ClientUser from .router import router -from fastapi import HTTPException, Path, Security +from fastapi import HTTPException, Path from pydantic import BaseModel, Field from sqlmodel import select @@ -27,8 +29,8 @@ class CheckResponse(BaseModel): ) async def check_user_relationship( db: Database, - user_id: int = Path(..., description="目标用户的 ID"), - current_user: User = Security(get_client_user), + user_id: Annotated[int, Path(..., description="目标用户的 ID")], + current_user: ClientUser, ): if user_id == current_user.id: raise HTTPException(422, "Cannot check relationship with yourself") diff --git a/app/router/private/score.py b/app/router/private/score.py index 75225bd..e640121 100644 --- a/app/router/private/score.py +++ b/app/router/private/score.py @@ -1,17 +1,14 @@ from __future__ import annotations from app.database.score import Score -from app.database.user import User -from app.dependencies.database import Database, get_redis -from app.dependencies.storage import get_storage_service -from app.dependencies.user import get_client_user +from app.dependencies.database import Database, Redis +from app.dependencies.storage import StorageService +from app.dependencies.user import ClientUser from app.service.user_cache_service import refresh_user_cache_background -from app.storage.base import StorageService from .router import router -from fastapi import BackgroundTasks, Depends, HTTPException, Security -from redis.asyncio import Redis +from fastapi import BackgroundTasks, HTTPException @router.delete( @@ -24,9 +21,9 @@ async def delete_score( session: Database, background_task: BackgroundTasks, score_id: int, - redis: Redis = Depends(get_redis), - current_user: User = Security(get_client_user), - storage_service: StorageService = Depends(get_storage_service), + redis: Redis, + current_user: ClientUser, + storage_service: StorageService, ): """删除成绩 diff --git a/app/router/private/team.py b/app/router/private/team.py index 681cd13..b60461c 100644 --- a/app/router/private/team.py +++ b/app/router/private/team.py @@ -1,12 +1,13 @@ from __future__ import annotations import hashlib +from typing import Annotated from app.database.team import Team, TeamMember, TeamRequest from app.database.user import BASE_INCLUDES, User, UserResp -from app.dependencies.database import Database, get_redis -from app.dependencies.storage import get_storage_service -from app.dependencies.user import get_client_user +from app.dependencies.database import Database, Redis +from app.dependencies.storage import StorageService +from app.dependencies.user import ClientUser from app.models.notification import ( TeamApplicationAccept, TeamApplicationReject, @@ -14,27 +15,25 @@ from app.models.notification import ( ) from app.router.notification import server from app.service.ranking_cache_service import get_ranking_cache_service -from app.storage.base import StorageService from app.utils import check_image, utcnow from .router import router -from fastapi import Depends, File, Form, HTTPException, Path, Request, Security +from fastapi import File, Form, HTTPException, Path, Request from pydantic import BaseModel -from redis.asyncio import Redis from sqlmodel import exists, select @router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"]) async def create_team( session: Database, - storage: StorageService = Depends(get_storage_service), - current_user: User = Security(get_client_user), - flag: bytes = File(..., description="战队图标文件"), - cover: bytes = File(..., description="战队头图文件"), - name: str = Form(max_length=100, description="战队名称"), - short_name: str = Form(max_length=10, description="战队缩写"), - redis: Redis = Depends(get_redis), + storage: StorageService, + current_user: ClientUser, + flag: Annotated[bytes, File(..., description="战队图标文件")], + cover: Annotated[bytes, File(..., description="战队头图文件")], + name: Annotated[str, Form(max_length=100, description="战队名称")], + short_name: Annotated[str, Form(max_length=10, description="战队缩写")], + redis: Redis, ): """创建战队。 @@ -88,13 +87,13 @@ async def create_team( async def update_team( team_id: int, session: Database, - storage: StorageService = Depends(get_storage_service), - current_user: User = Security(get_client_user), - flag: bytes | None = File(default=None, description="战队图标文件"), - cover: bytes | None = File(default=None, description="战队头图文件"), - name: str | None = Form(default=None, max_length=100, description="战队名称"), - short_name: str | None = Form(default=None, max_length=10, description="战队缩写"), - leader_id: int | None = Form(default=None, description="战队队长 ID"), + storage: StorageService, + current_user: ClientUser, + flag: Annotated[bytes | None, File(description="战队图标文件")] = None, + cover: Annotated[bytes | None, File(description="战队头图文件")] = None, + name: Annotated[str | None, Form(max_length=100, description="战队名称")] = None, + short_name: Annotated[str | None, Form(max_length=10, description="战队缩写")] = None, + leader_id: Annotated[int | None, Form(description="战队队长 ID")] = None, ): """修改战队。 @@ -161,9 +160,9 @@ async def update_team( @router.delete("/team/{team_id}", name="删除战队", status_code=204, tags=["战队", "g0v0 API"]) async def delete_team( session: Database, - team_id: int = Path(..., description="战队 ID"), - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), + team_id: Annotated[int, Path(..., description="战队 ID")], + current_user: ClientUser, + redis: Redis, ): team = await session.get(Team, team_id) if not team: @@ -191,7 +190,7 @@ class TeamQueryResp(BaseModel): @router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"]) async def get_team( session: Database, - team_id: int = Path(..., description="战队 ID"), + team_id: Annotated[int, Path(..., description="战队 ID")], ): members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all() return TeamQueryResp( @@ -203,8 +202,8 @@ async def get_team( @router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"]) async def request_join_team( session: Database, - team_id: int = Path(..., description="战队 ID"), - current_user: User = Security(get_client_user), + team_id: Annotated[int, Path(..., description="战队 ID")], + current_user: ClientUser, ): team = await session.get(Team, team_id) if not team: @@ -231,10 +230,10 @@ async def request_join_team( async def handle_request( req: Request, session: Database, - team_id: int = Path(..., description="战队 ID"), - user_id: int = Path(..., description="用户 ID"), - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), + team_id: Annotated[int, Path(..., description="战队 ID")], + user_id: Annotated[int, Path(..., description="用户 ID")], + current_user: ClientUser, + redis: Redis, ): team = await session.get(Team, team_id) if not team: @@ -272,10 +271,10 @@ async def handle_request( @router.delete("/team/{team_id}/{user_id}", name="踢出成员 / 退出战队", status_code=204, tags=["战队", "g0v0 API"]) async def kick_member( session: Database, - team_id: int = Path(..., description="战队 ID"), - user_id: int = Path(..., description="用户 ID"), - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), + team_id: Annotated[int, Path(..., description="战队 ID")], + user_id: Annotated[int, Path(..., description="用户 ID")], + current_user: ClientUser, + redis: Redis, ): team = await session.get(Team, team_id) if not team: diff --git a/app/router/private/totp.py b/app/router/private/totp.py index 2435567..06406aa 100644 --- a/app/router/private/totp.py +++ b/app/router/private/totp.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Annotated + from app.auth import ( check_totp_backup_code, finish_create_totp_key, @@ -9,17 +11,15 @@ from app.auth import ( ) from app.const import BACKUP_CODE_LENGTH from app.database.auth import TotpKeys -from app.database.user import User -from app.dependencies.database import Database, get_redis -from app.dependencies.user import get_client_user +from app.dependencies.database import Database, Redis +from app.dependencies.user import ClientUser from app.models.totp import FinishStatus, StartCreateTotpKeyResp from .router import router -from fastapi import Body, Depends, HTTPException, Security +from fastapi import Body, HTTPException from pydantic import BaseModel import pyotp -from redis.asyncio import Redis class TotpStatusResp(BaseModel): @@ -37,7 +37,7 @@ class TotpStatusResp(BaseModel): response_model=TotpStatusResp, ) async def get_totp_status( - current_user: User = Security(get_client_user), + current_user: ClientUser, ): """检查用户是否已创建TOTP""" totp_key = await current_user.awaitable_attrs.totp_key @@ -62,8 +62,8 @@ async def get_totp_status( status_code=201, ) async def start_create_totp( - redis: Redis = Depends(get_redis), - current_user: User = Security(get_client_user), + redis: Redis, + current_user: ClientUser, ): if await current_user.awaitable_attrs.totp_key: raise HTTPException(status_code=400, detail="TOTP is already enabled for this user") @@ -98,9 +98,9 @@ async def start_create_totp( ) async def finish_create_totp( session: Database, - code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"), - redis: Redis = Depends(get_redis), - current_user: User = Security(get_client_user), + code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码")], + redis: Redis, + current_user: ClientUser, ): status, backup_codes = await finish_create_totp_key(current_user, code, redis, session) if status == FinishStatus.SUCCESS: @@ -122,9 +122,9 @@ async def finish_create_totp( ) async def disable_totp( session: Database, - code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"), - redis: Redis = Depends(get_redis), - current_user: User = Security(get_client_user), + code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码或备份码")], + redis: Redis, + current_user: ClientUser, ): totp = await session.get(TotpKeys, current_user.id) if not totp: diff --git a/app/router/private/username.py b/app/router/private/username.py index 571cd40..18eb219 100644 --- a/app/router/private/username.py +++ b/app/router/private/username.py @@ -1,24 +1,26 @@ from __future__ import annotations +from typing import Annotated + from app.auth import validate_username from app.config import settings from app.database.events import Event, EventType from app.database.user import User from app.dependencies.database import Database -from app.dependencies.user import get_client_user +from app.dependencies.user import ClientUser from app.utils import utcnow from .router import router -from fastapi import Body, HTTPException, Security +from fastapi import Body, HTTPException from sqlmodel import exists, select @router.post("/rename", name="修改用户名", tags=["用户", "g0v0 API"]) async def user_rename( session: Database, - new_name: str = Body(..., description="新的用户名"), - current_user: User = Security(get_client_user), + new_name: Annotated[str, Body(..., description="新的用户名")], + current_user: ClientUser, ): """修改用户名 diff --git a/app/router/v1/beatmap.py b/app/router/v1/beatmap.py index 3301fd2..b723713 100644 --- a/app/router/v1/beatmap.py +++ b/app/router/v1/beatmap.py @@ -1,24 +1,22 @@ from __future__ import annotations from datetime import datetime -from typing import Literal +from typing import Annotated, Literal from app.database.beatmap import Beatmap, calculate_beatmap_attributes from app.database.beatmap_playcounts import BeatmapPlaycounts from app.database.beatmapset import Beatmapset from app.database.favourite_beatmapset import FavouriteBeatmapset from app.database.score import Score -from app.dependencies.database import Database, get_redis -from app.dependencies.fetcher import get_fetcher -from app.fetcher import Fetcher +from app.dependencies.database import Database, Redis +from app.dependencies.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.mods import int_to_mods from app.models.score import GameMode from .router import AllStrModel, router -from fastapi import Depends, Query -from redis.asyncio import Redis +from fastapi import Query from sqlmodel import col, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -148,18 +146,18 @@ class V1Beatmap(AllStrModel): ) async def get_beatmaps( session: Database, - since: datetime | None = Query(None, description="自指定时间后拥有排行榜的谱面"), - beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"), - beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"), - user: str | None = Query(None, alias="u", description="谱师"), - type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), - ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO - convert: bool = Query(False, alias="a", description="转谱"), # TODO - checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"), - limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"), - mods: int = Query(0, description="应用到谱面属性的 MOD"), - redis: Redis = Depends(get_redis), - fetcher: Fetcher = Depends(get_fetcher), + redis: Redis, + fetcher: Fetcher, + since: Annotated[datetime | None, Query(description="自指定时间后拥有排行榜的谱面")] = None, + beatmapset_id: Annotated[int | None, Query(alias="s", description="谱面集 ID")] = None, + beatmap_id: Annotated[int | None, Query(alias="b", description="谱面 ID")] = None, + user: Annotated[str | None, Query(alias="u", description="谱师")] = None, + type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None, + ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0, le=3)] = None, # TODO + convert: Annotated[bool, Query(alias="a", description="转谱")] = False, # TODO + checksum: Annotated[str | None, Query(alias="h", description="谱面文件 MD5")] = None, + limit: Annotated[int, Query(ge=1, le=500, description="返回结果数量限制")] = 500, + mods: Annotated[int, Query(description="应用到谱面属性的 MOD")] = 0, ): beatmaps: list[Beatmap] = [] results = [] diff --git a/app/router/v1/public_user.py b/app/router/v1/public_user.py index 1f5df71..dadbfb9 100644 --- a/app/router/v1/public_user.py +++ b/app/router/v1/public_user.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Annotated, Literal from app.database.statistics import UserStatistics from app.database.user import User @@ -181,9 +181,9 @@ async def _count_online_users_optimized(redis): ) async def api_get_player_info( session: Database, - scope: Literal["stats", "events", "info", "all"] = Query(..., description="信息范围"), - id: int | None = Query(None, ge=3, le=2147483647, description="用户 ID"), - name: str | None = Query(None, regex=r"^[\w \[\]-]{2,32}$", description="用户名"), + scope: Annotated[Literal["stats", "events", "info", "all"], Query(..., description="信息范围")], + id: Annotated[int | None, Query(ge=3, le=2147483647, description="用户 ID")] = None, + name: Annotated[str | None, Query(regex=r"^[\w \[\]-]{2,32}$", description="用户名")] = None, ): """ 获取指定玩家的信息 diff --git a/app/router/v1/replay.py b/app/router/v1/replay.py index e37057e..b0ffc99 100644 --- a/app/router/v1/replay.py +++ b/app/router/v1/replay.py @@ -2,15 +2,14 @@ from __future__ import annotations import base64 from datetime import date -from typing import Literal +from typing import Annotated, Literal from app.database.counts import ReplayWatchedCount from app.database.score import Score from app.dependencies.database import Database -from app.dependencies.storage import get_storage_service +from app.dependencies.storage import StorageService from app.models.mods import int_to_mods from app.models.score import GameMode -from app.storage import StorageService from .router import router @@ -34,18 +33,20 @@ class ReplayModel(BaseModel): ) async def download_replay( session: Database, - beatmap: int = Query(..., alias="b", description="谱面 ID"), - user: str = Query(..., alias="u", description="用户"), - ruleset_id: int | None = Query( - None, - alias="m", - description="Ruleset ID", - ge=0, - ), - score_id: int | None = Query(None, alias="s", description="成绩 ID"), - type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), - mods: int = Query(0, description="成绩的 MOD"), - storage_service: StorageService = Depends(get_storage_service), + beatmap: Annotated[int, Query(..., alias="b", description="谱面 ID")], + user: Annotated[str, Query(..., alias="u", description="用户")], + storage_service: StorageService, + ruleset_id: Annotated[ + int | None, + Query( + alias="m", + description="Ruleset ID", + ge=0, + ), + ] = None, + score_id: Annotated[int | None, Query(alias="s", description="成绩 ID")] = None, + type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None, + mods: Annotated[int, Query(description="成绩的 MOD")] = 0, ): mods_ = int_to_mods(mods) if score_id is not None: diff --git a/app/router/v1/score.py b/app/router/v1/score.py index ccbcf4f..4ac9b42 100644 --- a/app/router/v1/score.py +++ b/app/router/v1/score.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import Literal +from typing import Annotated, Literal from app.database.best_scores import PPBestScore from app.database.score import Score, get_leaderboard @@ -69,10 +69,10 @@ class V1Score(AllStrModel): ) async def get_user_best( session: Database, - user: str = Query(..., alias="u", description="用户"), - ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), - limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), + user: Annotated[str, Query(..., alias="u", description="用户")], + ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0, + type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None, + limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10, ): try: scores = ( @@ -101,10 +101,10 @@ async def get_user_best( ) async def get_user_recent( session: Database, - user: str = Query(..., alias="u", description="用户"), - ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), - limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), + user: Annotated[str, Query(..., alias="u", description="用户")], + ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0, + type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None, + limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10, ): try: scores = ( @@ -133,12 +133,12 @@ async def get_user_recent( ) async def get_scores( session: Database, - user: str | None = Query(None, alias="u", description="用户"), - beatmap_id: int = Query(alias="b", description="谱面 ID"), - ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), - limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), - mods: int = Query(0, description="成绩的 MOD"), + beatmap_id: Annotated[int, Query(alias="b", description="谱面 ID")], + user: Annotated[str | None, Query(alias="u", description="用户")] = None, + ruleset_id: Annotated[int, Query(alias="m", description="Ruleset ID", ge=0)] = 0, + type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None, + limit: Annotated[int, Query(ge=1, le=100, description="返回的成绩数量")] = 10, + mods: Annotated[int, Query(description="成绩的 MOD")] = 0, ): try: if user is not None: diff --git a/app/router/v1/user.py b/app/router/v1/user.py index 77e0369..52ee19a 100644 --- a/app/router/v1/user.py +++ b/app/router/v1/user.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import Literal +from typing import Annotated, Literal from app.database.statistics import UserStatistics, UserStatisticsResp from app.database.user import User @@ -104,10 +104,10 @@ class V1User(AllStrModel): async def get_user( session: Database, background_tasks: BackgroundTasks, - user: str = Query(..., alias="u", description="用户"), - ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), - event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"), + user: Annotated[str, Query(..., alias="u", description="用户")], + ruleset_id: Annotated[int | None, Query(alias="m", description="Ruleset ID", ge=0)] = None, + type: Annotated[Literal["string", "id"] | None, Query(description="用户类型:string 用户名称 / id 用户 ID")] = None, + event_days: Annotated[int, Query(ge=1, le=31, description="从现在起所有事件的最大天数")] = 1, ): redis = get_redis() cache_service = get_user_cache_service(redis) diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py index 506f9f3..024a542 100644 --- a/app/router/v2/beatmap.py +++ b/app/router/v2/beatmap.py @@ -3,13 +3,13 @@ from __future__ import annotations import asyncio import hashlib import json +from typing import Annotated from app.database import Beatmap, BeatmapResp, User from app.database.beatmap import calculate_beatmap_attributes -from app.dependencies.database import Database, get_redis -from app.dependencies.fetcher import get_fetcher +from app.dependencies.database import Database, Redis +from app.dependencies.fetcher import Fetcher from app.dependencies.user import get_current_user -from app.fetcher import Fetcher from app.models.beatmap import BeatmapAttributes from app.models.mods import APIMod, int_to_mods from app.models.score import ( @@ -18,10 +18,9 @@ from app.models.score import ( from .router import router -from fastapi import Depends, HTTPException, Path, Query, Security +from fastapi import HTTPException, Path, Query, Security from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -from redis.asyncio import Redis import rosu_pp_py as rosu from sqlmodel import col, select @@ -44,11 +43,11 @@ class BatchGetResp(BaseModel): ) async def lookup_beatmap( db: Database, - id: int | None = Query(default=None, alias="id", description="谱面 ID"), - md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"), - filename: str | None = Query(default=None, alias="filename", description="谱面文件名"), - current_user: User = Security(get_current_user, scopes=["public"]), - fetcher: Fetcher = Depends(get_fetcher), + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + fetcher: Fetcher, + id: Annotated[int | None, Query(alias="id", description="谱面 ID")] = None, + md5: Annotated[str | None, Query(alias="checksum", description="谱面文件 MD5")] = None, + filename: Annotated[str | None, Query(alias="filename", description="谱面文件名")] = None, ): if id is None and md5 is None and filename is None: raise HTTPException( @@ -75,9 +74,9 @@ async def lookup_beatmap( ) async def get_beatmap( db: Database, - beatmap_id: int = Path(..., description="谱面 ID"), - current_user: User = Security(get_current_user, scopes=["public"]), - fetcher: Fetcher = Depends(get_fetcher), + beatmap_id: Annotated[int, Path(..., description="谱面 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + fetcher: Fetcher, ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) @@ -95,9 +94,12 @@ async def get_beatmap( ) async def batch_get_beatmaps( db: Database, - beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"), - current_user: User = Security(get_current_user, scopes=["public"]), - fetcher: Fetcher = Depends(get_fetcher), + beatmap_ids: Annotated[ + list[int], + Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"), + ], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + fetcher: Fetcher, ): if not beatmap_ids: beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all() @@ -127,16 +129,19 @@ async def batch_get_beatmaps( ) async def get_beatmap_attributes( db: Database, - beatmap_id: int = Path(..., description="谱面 ID"), - current_user: User = Security(get_current_user, scopes=["public"]), - mods: list[str] = Query( - default_factory=list, - description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称", - ), - ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"), - ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3), - redis: Redis = Depends(get_redis), - fetcher: Fetcher = Depends(get_fetcher), + beatmap_id: Annotated[int, Path(..., description="谱面 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + mods: Annotated[ + list[str], + Query( + default_factory=list, + description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称", + ), + ], + redis: Redis, + fetcher: Fetcher, + ruleset: Annotated[GameMode | None, Query(description="指定 ruleset;为空则使用谱面自身模式")] = None, + ruleset_id: Annotated[int | None, Query(description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3)] = None, ): mods_ = [] if mods and mods[0].isdigit(): diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py index 01b5658..c4f2561 100644 --- a/app/router/v2/beatmapset.py +++ b/app/router/v2/beatmapset.py @@ -6,23 +6,20 @@ from urllib.parse import parse_qs from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.database.beatmapset import SearchBeatmapsetsResp -from app.dependencies.beatmap_download import get_beatmap_download_service -from app.dependencies.beatmapset_cache import get_beatmapset_cache_dependency -from app.dependencies.database import Database, get_redis, with_db -from app.dependencies.fetcher import get_fetcher -from app.dependencies.geoip import get_client_ip, get_geoip_helper -from app.dependencies.user import get_client_user, get_current_user -from app.fetcher import Fetcher +from app.dependencies.beatmap_download import DownloadService +from app.dependencies.beatmapset_cache import BeatmapsetCacheService +from app.dependencies.database import Database, Redis, with_db +from app.dependencies.fetcher import Fetcher +from app.dependencies.geoip import IPAddress, get_geoip_helper +from app.dependencies.user import ClientUser, get_current_user from app.models.beatmap import SearchQueryModel from app.service.asset_proxy_helper import process_response_assets -from app.service.beatmap_download_service import BeatmapDownloadService -from app.service.beatmapset_cache_service import BeatmapsetCacheService, generate_hash +from app.service.beatmapset_cache_service import generate_hash from .router import router from fastapi import ( BackgroundTasks, - Depends, Form, HTTPException, Path, @@ -53,10 +50,10 @@ async def search_beatmapset( query: Annotated[SearchQueryModel, Query(...)], request: Request, background_tasks: BackgroundTasks, - current_user: User = Security(get_current_user, scopes=["public"]), - fetcher: Fetcher = Depends(get_fetcher), - redis=Depends(get_redis), - cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency), + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + fetcher: Fetcher, + redis: Redis, + cache_service: BeatmapsetCacheService, ): params = parse_qs(qs=request.url.query, keep_blank_values=True) cursor = {} @@ -134,10 +131,10 @@ async def search_beatmapset( async def lookup_beatmapset( db: Database, request: Request, - beatmap_id: int = Query(description="谱面 ID"), - current_user: User = Security(get_current_user, scopes=["public"]), - fetcher: Fetcher = Depends(get_fetcher), - cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency), + beatmap_id: Annotated[int, Query(description="谱面 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + fetcher: Fetcher, + cache_service: BeatmapsetCacheService, ): # 先尝试从缓存获取 cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id) @@ -170,10 +167,10 @@ async def lookup_beatmapset( async def get_beatmapset( db: Database, request: Request, - beatmapset_id: int = Path(..., description="谱面集 ID"), - current_user: User = Security(get_current_user, scopes=["public"]), - fetcher: Fetcher = Depends(get_fetcher), - cache_service: BeatmapsetCacheService = Depends(get_beatmapset_cache_dependency), + beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + fetcher: Fetcher, + cache_service: BeatmapsetCacheService, ): # 先尝试从缓存获取 cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id) @@ -203,14 +200,12 @@ async def get_beatmapset( description="\n下载谱面集文件。基于请求IP地理位置智能分流,支持负载均衡和自动故障转移。中国IP使用Sayobot镜像,其他地区使用Nerinyan和OsuDirect镜像。", ) async def download_beatmapset( - request: Request, - beatmapset_id: int = Path(..., description="谱面集 ID"), - no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"), - current_user: User = Security(get_client_user), - download_service: BeatmapDownloadService = Depends(get_beatmap_download_service), + client_ip: IPAddress, + beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")], + current_user: ClientUser, + download_service: DownloadService, + no_video: Annotated[bool, Query(alias="noVideo", description="是否下载无视频版本")] = True, ): - client_ip = get_client_ip(request) - geoip_helper = get_geoip_helper() geo_info = geoip_helper.lookup(client_ip) country_code = geo_info.get("country_iso", "") @@ -242,9 +237,12 @@ async def download_beatmapset( ) async def favourite_beatmapset( db: Database, - beatmapset_id: int = Path(..., description="谱面集 ID"), - action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"), - current_user: User = Security(get_client_user), + beatmapset_id: Annotated[int, Path(..., description="谱面集 ID")], + action: Annotated[ + Literal["favourite", "unfavourite"], + Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"), + ], + current_user: ClientUser, ): existing_favourite = ( await db.exec( diff --git a/app/router/v2/cache.py b/app/router/v2/cache.py index 08a0b27..fe610a6 100644 --- a/app/router/v2/cache.py +++ b/app/router/v2/cache.py @@ -5,14 +5,13 @@ from __future__ import annotations -from app.dependencies.database import get_redis +from app.dependencies.database import Redis from app.service.user_cache_service import get_user_cache_service from .router import router -from fastapi import Depends, HTTPException +from fastapi import HTTPException from pydantic import BaseModel -from redis.asyncio import Redis class CacheStatsResponse(BaseModel): @@ -28,7 +27,7 @@ class CacheStatsResponse(BaseModel): tags=["缓存管理"], ) async def get_cache_stats( - redis: Redis = Depends(get_redis), + redis: Redis, # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释,可根据需要启用 ): try: @@ -68,7 +67,7 @@ async def get_cache_stats( ) async def invalidate_user_cache( user_id: int, - redis: Redis = Depends(get_redis), + redis: Redis, # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 ): try: @@ -87,7 +86,7 @@ async def invalidate_user_cache( tags=["缓存管理"], ) async def clear_all_user_cache( - redis: Redis = Depends(get_redis), + redis: Redis, # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 ): try: @@ -119,7 +118,7 @@ class CacheWarmupRequest(BaseModel): ) async def warmup_cache( request: CacheWarmupRequest, - redis: Redis = Depends(get_redis), + redis: Redis, # current_user: User = Security(get_current_user, scopes=["admin"]), # 暂时注释 ): try: diff --git a/app/router/v2/me.py b/app/router/v2/me.py index fe1e797..5304e9d 100644 --- a/app/router/v2/me.py +++ b/app/router/v2/me.py @@ -1,9 +1,10 @@ from __future__ import annotations +from typing import Annotated + from app.database import MeResp, User -from app.dependencies import get_current_user from app.dependencies.database import Database -from app.dependencies.user import UserAndToken, get_current_user_and_token +from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token from app.exceptions.userpage import UserpageError from app.models.score import GameMode from app.models.user import Page @@ -29,8 +30,8 @@ from fastapi import HTTPException, Path, Security ) async def get_user_info_with_ruleset( session: Database, - ruleset: GameMode = Path(description="指定 ruleset"), - user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), + ruleset: Annotated[GameMode, Path(description="指定 ruleset")], + user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])], ): user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id) return user_resp @@ -45,7 +46,7 @@ async def get_user_info_with_ruleset( ) async def get_user_info_default( session: Database, - user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), + user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])], ): user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id) return user_resp @@ -85,8 +86,8 @@ async def get_user_info_default( async def update_userpage( request: UpdateUserpageRequest, session: Database, - user_id: int = Path(description="用户ID"), - current_user: User = Security(get_current_user, scopes=["edit"]), + user_id: Annotated[int, Path(description="用户ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["edit"])], ): """更新用户页面内容(匹配官方osu-web实现)""" # 检查权限:只能编辑自己的页面(除非是管理员) diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py index b6893f5..f2c6236 100644 --- a/app/router/v2/ranking.py +++ b/app/router/v2/ranking.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import Literal +from typing import Annotated, Literal from app.config import settings from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp -from app.dependencies import get_current_user from app.dependencies.database import Database, get_redis +from app.dependencies.user import get_current_user from app.models.score import GameMode from app.service.ranking_cache_service import get_ranking_cache_service @@ -45,11 +45,11 @@ SortType = Literal["performance", "score"] async def get_team_ranking_pp( session: Database, background_tasks: BackgroundTasks, - ruleset: GameMode = Path(..., description="指定 ruleset"), - page: int = Query(1, ge=1, description="页码"), - current_user: User = Security(get_current_user, scopes=["public"]), + ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + page: Annotated[int, Query(ge=1, description="页码")] = 1, ): - return await get_team_ranking(session, background_tasks, "performance", ruleset, page, current_user) + return await get_team_ranking(session, background_tasks, "performance", ruleset, current_user, page) @router.get( @@ -62,14 +62,17 @@ async def get_team_ranking_pp( async def get_team_ranking( session: Database, background_tasks: BackgroundTasks, - sort: SortType = Path( - ..., - description="排名类型:performance 表现分 / score 计分成绩总分 " - "**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**", - ), - ruleset: GameMode = Path(..., description="指定 ruleset"), - page: int = Query(1, ge=1, description="页码"), - current_user: User = Security(get_current_user, scopes=["public"]), + sort: Annotated[ + SortType, + Path( + ..., + description="排名类型:performance 表现分 / score 计分成绩总分 " + "**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**", + ), + ], + ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + page: Annotated[int, Query(ge=1, description="页码")] = 1, ): # 获取 Redis 连接和缓存服务 redis = get_redis() @@ -193,11 +196,11 @@ class CountryResponse(BaseModel): async def get_country_ranking_pp( session: Database, background_tasks: BackgroundTasks, - ruleset: GameMode = Path(..., description="指定 ruleset"), - page: int = Query(1, ge=1, description="页码"), - current_user: User = Security(get_current_user, scopes=["public"]), + ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + page: Annotated[int, Query(ge=1, description="页码")] = 1, ): - return await get_country_ranking(session, background_tasks, ruleset, page, "performance", current_user) + return await get_country_ranking(session, background_tasks, ruleset, "performance", current_user, page) @router.get( @@ -210,14 +213,17 @@ async def get_country_ranking_pp( async def get_country_ranking( session: Database, background_tasks: BackgroundTasks, - ruleset: GameMode = Path(..., description="指定 ruleset"), - page: int = Query(1, ge=1, description="页码"), - sort: SortType = Path( - ..., - description="排名类型:performance 表现分 / score 计分成绩总分 " - "**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**", - ), - current_user: User = Security(get_current_user, scopes=["public"]), + ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")], + sort: Annotated[ + SortType, + Path( + ..., + description="排名类型:performance 表现分 / score 计分成绩总分 " + "**这个参数是本服务器额外添加的,不属于 v2 API 的一部分**", + ), + ], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + page: Annotated[int, Query(ge=1, description="页码")] = 1, ): # 获取 Redis 连接和缓存服务 redis = get_redis() @@ -317,11 +323,11 @@ class TopUsersResponse(BaseModel): async def get_user_ranking( session: Database, background_tasks: BackgroundTasks, - ruleset: GameMode = Path(..., description="指定 ruleset"), - sort: SortType = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"), - country: str | None = Query(None, description="国家代码"), - page: int = Query(1, ge=1, description="页码"), - current_user: User = Security(get_current_user, scopes=["public"]), + ruleset: Annotated[GameMode, Path(..., description="指定 ruleset")], + sort: Annotated[SortType, Path(..., description="排名类型:performance 表现分 / score 计分成绩总分")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + country: Annotated[str | None, Query(description="国家代码")] = None, + page: Annotated[int, Query(ge=1, description="页码")] = 1, ): # 获取 Redis 连接和缓存服务 redis = get_redis() diff --git a/app/router/v2/relationship.py b/app/router/v2/relationship.py index b028951..4851e2a 100644 --- a/app/router/v2/relationship.py +++ b/app/router/v2/relationship.py @@ -1,10 +1,12 @@ from __future__ import annotations +from typing import Annotated + from app.database import Relationship, RelationshipResp, RelationshipType, User from app.database.user import UserResp from app.dependencies.api_version import APIVersion from app.dependencies.database import Database -from app.dependencies.user import get_client_user, get_current_user +from app.dependencies.user import ClientUser, get_current_user from .router import router @@ -56,7 +58,7 @@ async def get_relationship( db: Database, request: Request, api_version: APIVersion, - current_user: User = Security(get_current_user, scopes=["friends.read"]), + current_user: Annotated[User, Security(get_current_user, scopes=["friends.read"])], ): relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK relationships = await db.exec( @@ -107,8 +109,8 @@ class AddFriendResp(BaseModel): async def add_relationship( db: Database, request: Request, - target: int = Query(description="目标用户 ID"), - current_user: User = Security(get_client_user), + target: Annotated[int, Query(description="目标用户 ID")], + current_user: ClientUser, ): if not (await db.exec(select(exists()).where(User.id == target))).first(): raise HTTPException(404, "Target user not found") @@ -176,8 +178,8 @@ async def add_relationship( async def delete_relationship( db: Database, request: Request, - target: int = Path(..., description="目标用户 ID"), - current_user: User = Security(get_client_user), + target: Annotated[int, Path(..., description="目标用户 ID")], + current_user: ClientUser, ): if not (await db.exec(select(exists()).where(User.id == target))).first(): raise HTTPException(404, "Target user not found") diff --git a/app/router/v2/room.py b/app/router/v2/room.py index 6a42dbc..21ea109 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import UTC -from typing import Literal +from typing import Annotated, Literal from app.database.beatmap import Beatmap, BeatmapResp from app.database.beatmapset import BeatmapsetResp @@ -12,8 +12,8 @@ from app.database.room import APIUploadedRoom, Room, RoomResp from app.database.room_participated_user import RoomParticipatedUser from app.database.score import Score from app.database.user import User, UserResp -from app.dependencies.database import Database, get_redis -from app.dependencies.user import get_client_user, get_current_user +from app.dependencies.database import Database, Redis +from app.dependencies.user import ClientUser, get_current_user from app.models.room import RoomCategory, RoomStatus from app.service.room import create_playlist_room_from_api from app.signalr.hub import MultiplayerHubs @@ -21,9 +21,8 @@ from app.utils import utcnow from .router import router -from fastapi import Depends, HTTPException, Path, Query, Security +from fastapi import HTTPException, Path, Query, Security from pydantic import BaseModel, Field -from redis.asyncio import Redis from sqlalchemy.sql.elements import ColumnElement from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -38,16 +37,20 @@ from sqlmodel.ext.asyncio.session import AsyncSession ) async def get_all_rooms( db: Database, - mode: Literal["open", "ended", "participated", "owned"] | None = Query( - default="open", - description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"), - ), - category: RoomCategory = Query( - RoomCategory.NORMAL, - description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"), - ), - status: RoomStatus | None = Query(None, description="房间状态(可选)"), - current_user: User = Security(get_current_user, scopes=["public"]), + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + mode: Annotated[ + Literal["open", "ended", "participated", "owned"] | None, + Query( + description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"), + ), + ] = "open", + category: Annotated[ + RoomCategory, + Query( + description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"), + ), + ] = RoomCategory.NORMAL, + status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None, ): resp_list: list[RoomResp] = [] where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category] @@ -140,8 +143,8 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session: async def create_room( db: Database, room: APIUploadedRoom, - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), + current_user: ClientUser, + redis: Redis, ): user_id = current_user.id db_room = await create_playlist_room_from_api(db, room, user_id) @@ -162,13 +165,15 @@ async def create_room( ) async def get_room( db: Database, - room_id: int = Path(..., description="房间 ID"), - category: str = Query( - default="", - description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"), - ), - current_user: User = Security(get_current_user, scopes=["public"]), - redis: Redis = Depends(get_redis), + room_id: Annotated[int, Path(..., description="房间 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + redis: Redis, + category: Annotated[ + str, + Query( + description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"), + ), + ] = "", ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is None: @@ -185,8 +190,8 @@ async def get_room( ) async def delete_room( db: Database, - room_id: int = Path(..., description="房间 ID"), - current_user: User = Security(get_client_user), + room_id: Annotated[int, Path(..., description="房间 ID")], + current_user: ClientUser, ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is None: @@ -205,10 +210,10 @@ async def delete_room( ) async def add_user_to_room( db: Database, - room_id: int = Path(..., description="房间 ID"), - user_id: int = Path(..., description="用户 ID"), - redis: Redis = Depends(get_redis), - current_user: User = Security(get_client_user), + room_id: Annotated[int, Path(..., description="房间 ID")], + user_id: Annotated[int, Path(..., description="用户 ID")], + redis: Redis, + current_user: ClientUser, ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is not None: @@ -229,10 +234,10 @@ async def add_user_to_room( ) async def remove_user_from_room( db: Database, - room_id: int = Path(..., description="房间 ID"), - user_id: int = Path(..., description="用户 ID"), - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), + room_id: Annotated[int, Path(..., description="房间 ID")], + user_id: Annotated[int, Path(..., description="用户 ID")], + current_user: ClientUser, + redis: Redis, ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is not None: @@ -273,8 +278,8 @@ class APILeaderboard(BaseModel): ) async def get_room_leaderboard( db: Database, - room_id: int = Path(..., description="房间 ID"), - current_user: User = Security(get_current_user, scopes=["public"]), + room_id: Annotated[int, Path(..., description="房间 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is None: @@ -329,11 +334,11 @@ class RoomEvents(BaseModel): ) async def get_room_events( db: Database, - room_id: int = Path(..., description="房间 ID"), - current_user: User = Security(get_current_user, scopes=["public"]), - limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), - after: int | None = Query(None, ge=0, description="仅包含大于该事件 ID 的事件"), - before: int | None = Query(None, ge=0, description="仅包含小于该事件 ID 的事件"), + room_id: Annotated[int, Path(..., description="房间 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100, + after: Annotated[int | None, Query(ge=0, description="仅包含大于该事件 ID 的事件")] = None, + before: Annotated[int | None, Query(ge=0, description="仅包含小于该事件 ID 的事件")] = None, ): events = ( await db.exec( diff --git a/app/router/v2/score.py b/app/router/v2/score.py index ff5765f..0da6de0 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -2,6 +2,7 @@ from __future__ import annotations from datetime import UTC, date import time +from typing import Annotated from app.calculator import clamp from app.config import settings @@ -34,11 +35,10 @@ from app.database.score import ( process_user, ) from app.dependencies.api_version import APIVersion -from app.dependencies.database import Database, get_redis, with_db -from app.dependencies.fetcher import get_fetcher -from app.dependencies.storage import get_storage_service -from app.dependencies.user import get_client_user, get_current_user -from app.fetcher import Fetcher +from app.dependencies.database import Database, Redis, get_redis, with_db +from app.dependencies.fetcher import Fetcher, get_fetcher +from app.dependencies.storage import StorageService +from app.dependencies.user import ClientUser, get_current_user from app.log import logger from app.models.beatmap import BeatmapRankStatus from app.models.room import RoomCategory @@ -50,7 +50,6 @@ from app.models.score import ( ) from app.service.beatmap_cache_service import get_beatmap_cache_service from app.service.user_cache_service import refresh_user_cache_background -from app.storage.base import StorageService from app.utils import utcnow from .router import router @@ -69,7 +68,6 @@ from fastapi.responses import RedirectResponse from fastapi_limiter.depends import RateLimiter from httpx import HTTPError from pydantic import BaseModel -from redis.asyncio import Redis from sqlalchemy.orm import joinedload from sqlmodel import col, exists, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -245,16 +243,18 @@ class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel): async def get_beatmap_scores( db: Database, api_version: APIVersion, - beatmap_id: int = Path(description="谱面 ID"), - mode: GameMode = Query(description="指定 auleset"), - legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), - mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"), - type: LeaderboardType = Query( - LeaderboardType.GLOBAL, - description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"), - ), - current_user: User = Security(get_current_user, scopes=["public"]), - limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"), + beatmap_id: Annotated[int, Path(description="谱面 ID")], + mode: Annotated[GameMode, Query(description="指定 auleset")], + mods: Annotated[list[str], Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None, + type: Annotated[ + LeaderboardType, + Query( + description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"), + ), + ] = LeaderboardType.GLOBAL, + limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50, ): if legacy_only: raise HTTPException(status_code=404, detail="this server only contains lazer scores") @@ -294,12 +294,12 @@ async def get_beatmap_scores( async def get_user_beatmap_score( db: Database, api_version: APIVersion, - beatmap_id: int = Path(description="谱面 ID"), - user_id: int = Path(description="用户 ID"), - legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), - mode: GameMode | None = Query(None, description="指定 ruleset (可选)"), - mods: str = Query(None, description="筛选使用的 Mods (暂未实现)"), - current_user: User = Security(get_current_user, scopes=["public"]), + beatmap_id: Annotated[int, Path(description="谱面 ID")], + user_id: Annotated[int, Path(description="用户 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None, + mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None, + mods: Annotated[str | None, Query(description="筛选使用的 Mods (暂未实现)")] = None, ): user_score = ( await db.exec( @@ -342,11 +342,11 @@ async def get_user_beatmap_score( async def get_user_all_beatmap_scores( db: Database, api_version: APIVersion, - beatmap_id: int = Path(description="谱面 ID"), - user_id: int = Path(description="用户 ID"), - legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), - ruleset: GameMode | None = Query(None, description="指定 ruleset (可选)"), - current_user: User = Security(get_current_user, scopes=["public"]), + beatmap_id: Annotated[int, Path(description="谱面 ID")], + user_id: Annotated[int, Path(description="用户 ID")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + legacy_only: Annotated[bool | None, Query(description="是否只查询 Stable 分数")] = None, + ruleset: Annotated[GameMode | None, Query(description="指定 ruleset (可选)")] = None, ): all_user_scores = ( await db.exec( @@ -374,11 +374,11 @@ async def get_user_all_beatmap_scores( async def create_solo_score( background_task: BackgroundTasks, db: Database, - beatmap_id: int = Path(description="谱面 ID"), - version_hash: str = Form("", description="游戏版本哈希"), - beatmap_hash: str = Form(description="谱面文件哈希"), - ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), - current_user: User = Security(get_client_user), + beatmap_id: Annotated[int, Path(description="谱面 ID")], + beatmap_hash: Annotated[str, Form(description="谱面文件哈希")], + ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")], + current_user: ClientUser, + version_hash: Annotated[str, Form(description="游戏版本哈希")] = "", ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -406,12 +406,12 @@ async def create_solo_score( async def submit_solo_score( background_task: BackgroundTasks, db: Database, - beatmap_id: int = Path(description="谱面 ID"), - token: int = Path(description="成绩令牌 ID"), - info: SoloScoreSubmissionInfo = Body(description="成绩提交信息"), - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), - fetcher=Depends(get_fetcher), + beatmap_id: Annotated[int, Path(description="谱面 ID")], + token: Annotated[int, Path(description="成绩令牌 ID")], + info: Annotated[SoloScoreSubmissionInfo, Body(description="成绩提交信息")], + current_user: ClientUser, + redis: Redis, + fetcher: Fetcher, ): return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher) @@ -428,11 +428,11 @@ async def create_playlist_score( background_task: BackgroundTasks, room_id: int, playlist_id: int, - beatmap_id: int = Form(description="谱面 ID"), - beatmap_hash: str = Form(description="游戏版本哈希"), - ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), - version_hash: str = Form("", description="谱面版本哈希"), - current_user: User = Security(get_client_user), + beatmap_id: Annotated[int, Form(description="谱面 ID")], + beatmap_hash: Annotated[str, Form(description="游戏版本哈希")], + ruleset_id: Annotated[int, Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)")], + current_user: ClientUser, + version_hash: Annotated[str, Form(description="谱面版本哈希")] = "", ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -496,9 +496,9 @@ async def submit_playlist_score( playlist_id: int, token: int, info: SoloScoreSubmissionInfo, - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), - fetcher: Fetcher = Depends(get_fetcher), + current_user: ClientUser, + redis: Redis, + fetcher: Fetcher, ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -555,9 +555,9 @@ async def index_playlist_scores( session: Database, room_id: int, playlist_id: int, - limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"), - cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"), - current_user: User = Security(get_current_user, scopes=["public"]), + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + limit: Annotated[int, Query(ge=1, le=50, description="返回条数 (1-50)")] = 50, + cursor: Annotated[int, Query(alias="cursor[total_score]", description="分页游标(上一页最低分)")] = 2000000, ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -623,8 +623,8 @@ async def show_playlist_score( room_id: int, playlist_id: int, score_id: int, - current_user: User = Security(get_client_user), - redis: Redis = Depends(get_redis), + current_user: ClientUser, + redis: Redis, ): room = await session.get(Room, room_id) if not room: @@ -692,7 +692,7 @@ async def get_user_playlist_score( room_id: int, playlist_id: int, user_id: int, - current_user: User = Security(get_client_user), + current_user: ClientUser, ): score_record = None start_time = time.time() @@ -725,8 +725,8 @@ async def get_user_playlist_score( ) async def pin_score( db: Database, - score_id: int = Path(description="成绩 ID"), - current_user: User = Security(get_client_user), + score_id: Annotated[int, Path(description="成绩 ID")], + current_user: ClientUser, ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -770,8 +770,8 @@ async def pin_score( ) async def unpin_score( db: Database, - score_id: int = Path(description="成绩 ID"), - current_user: User = Security(get_client_user), + score_id: Annotated[int, Path(description="成绩 ID")], + current_user: ClientUser, ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -805,10 +805,10 @@ async def unpin_score( ) async def reorder_score_pin( db: Database, - score_id: int = Path(description="成绩 ID"), - after_score_id: int | None = Body(default=None, description="放在该成绩之后"), - before_score_id: int | None = Body(default=None, description="放在该成绩之前"), - current_user: User = Security(get_client_user), + score_id: Annotated[int, Path(description="成绩 ID")], + current_user: ClientUser, + after_score_id: Annotated[int | None, Body(description="放在该成绩之后")] = None, + before_score_id: Annotated[int | None, Body(description="放在该成绩之前")] = None, ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -893,8 +893,8 @@ async def reorder_score_pin( async def download_score_replay( score_id: int, db: Database, - current_user: User = Security(get_current_user, scopes=["public"]), - storage_service: StorageService = Depends(get_storage_service), + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + storage_service: StorageService, ): # 立即获取用户ID,避免懒加载问题 user_id = current_user.id diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index 81abb5d..079a0b3 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -11,8 +11,8 @@ from app.config import settings from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER from app.database.auth import TotpKeys from app.dependencies.api_version import APIVersion -from app.dependencies.database import Database, get_redis -from app.dependencies.geoip import get_client_ip +from app.dependencies.database import Database, Redis, get_redis +from app.dependencies.geoip import IPAddress from app.dependencies.user import UserAndToken, get_client_user_and_token from app.dependencies.user_agent import UserAgentInfo from app.log import logger @@ -27,7 +27,6 @@ from .router import router from fastapi import Depends, Form, Header, HTTPException, Request, Security, status from fastapi.responses import JSONResponse, Response from pydantic import BaseModel -from redis.asyncio import Redis class VerifyMethod(BaseModel): @@ -64,10 +63,14 @@ async def verify_session( db: Database, api_version: APIVersion, user_agent: UserAgentInfo, + ip_address: IPAddress, redis: Annotated[Redis, Depends(get_redis)], - verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"), - user_and_token: UserAndToken = Security(get_client_user_and_token), - web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"), + verification_key: Annotated[ + str, + Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"), + ], + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], + web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None, ) -> Response: current_user = user_and_token[0] token_id = user_and_token[1].id @@ -82,7 +85,6 @@ async def verify_session( else await LoginSessionService.get_login_method(user_id, token_id, redis) ) - ip_address = get_client_ip(request) login_method = "password" try: @@ -182,12 +184,12 @@ async def verify_session( tags=["验证"], ) async def reissue_verification_code( - request: Request, db: Database, user_agent: UserAgentInfo, api_version: APIVersion, + ip_address: IPAddress, redis: Annotated[Redis, Depends(get_redis)], - user_and_token: UserAndToken = Security(get_client_user_and_token), + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], ) -> SessionReissueResponse: current_user = user_and_token[0] token_id = user_and_token[1].id @@ -203,7 +205,6 @@ async def reissue_verification_code( return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码") try: - ip_address = get_client_ip(request) user_id = current_user.id success, message = await EmailVerificationService.resend_verification_code( db, @@ -233,17 +234,15 @@ async def reissue_verification_code( async def fallback_email( db: Database, user_agent: UserAgentInfo, - request: Request, + ip_address: IPAddress, redis: Annotated[Redis, Depends(get_redis)], - user_and_token: UserAndToken = Security(get_client_user_and_token), + user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)], ) -> VerifyMethod: current_user = user_and_token[0] token_id = user_and_token[1].id if not await LoginSessionService.get_login_method(current_user.id, token_id, redis): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退") - ip_address = get_client_ip(request) - await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis) success, message = await EmailVerificationService.resend_verification_code( db, diff --git a/app/router/v2/tags.py b/app/router/v2/tags.py index 99fccb2..644cd77 100644 --- a/app/router/v2/tags.py +++ b/app/router/v2/tags.py @@ -1,10 +1,12 @@ from __future__ import annotations +from typing import Annotated + from app.database.beatmap import Beatmap from app.database.beatmap_tags import BeatmapTagVote from app.database.score import Score from app.database.user import User -from app.dependencies.database import get_db +from app.dependencies.database import Database from app.dependencies.user import get_client_user from app.models.score import Rank from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id @@ -55,10 +57,10 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession description="为指定谱面添加标签投票。", ) async def vote_beatmap_tags( - beatmap_id: int = Path(..., description="谱面 ID"), - tag_id: int = Path(..., description="标签 ID"), - session: AsyncSession = Depends(get_db), - current_user: User = Depends(get_client_user), + beatmap_id: Annotated[int, Path(..., description="谱面 ID")], + tag_id: Annotated[int, Path(..., description="标签 ID")], + session: Database, + current_user: Annotated[User, Depends(get_client_user)], ): try: get_tag_by_id(tag_id) @@ -90,10 +92,10 @@ async def vote_beatmap_tags( description="取消对指定谱面标签的投票。", ) async def devote_beatmap_tags( - beatmap_id: int = Path(..., description="谱面 ID"), - tag_id: int = Path(..., description="标签 ID"), - session: AsyncSession = Depends(get_db), - current_user: User = Depends(get_client_user), + beatmap_id: Annotated[int, Path(..., description="谱面 ID")], + tag_id: Annotated[int, Path(..., description="标签 ID")], + session: Database, + current_user: Annotated[User, Depends(get_client_user)], ): """ 取消对谱面指定标签的投票。 diff --git a/app/router/v2/user.py b/app/router/v2/user.py index 7ff444a..7928029 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import timedelta -from typing import Literal +from typing import Annotated, Literal from app.config import settings from app.const import BANCHOBOT_ID @@ -51,9 +51,12 @@ async def get_users( session: Database, request: Request, background_task: BackgroundTasks, - user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"), + user_ids: Annotated[list[int], Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表")], # current_user: User = Security(get_current_user, scopes=["public"]), - include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use + include_variant_statistics: Annotated[ + bool, + Query(description="是否包含各模式的统计信息"), + ] = False, # TODO: future use ): redis = get_redis() cache_service = get_user_cache_service(redis) @@ -119,9 +122,9 @@ async def get_users( ) async def get_user_events( session: Database, - user_id: int = Path(description="用户 ID"), - limit: int | None = Query(None, description="限制返回的活动数量"), - offset: int | None = Query(None, description="活动日志的偏移量"), + user_id: Annotated[int, Path(description="用户 ID")], + limit: Annotated[int | None, Query(description="限制返回的活动数量")] = None, + offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None, ): db_user = await session.get(User, user_id) if db_user is None or db_user.id == BANCHOBOT_ID: @@ -147,9 +150,9 @@ async def get_user_events( ) async def get_user_kudosu( session: Database, - user_id: int = Path(description="用户 ID"), - offset: int = Query(default=0, description="偏移量"), - limit: int = Query(default=6, description="返回记录数量限制"), + user_id: Annotated[int, Path(description="用户 ID")], + offset: Annotated[int, Query(description="偏移量")] = 0, + limit: Annotated[int, Query(description="返回记录数量限制")] = 6, ): """ 获取用户的 kudosu 记录 @@ -176,8 +179,8 @@ async def get_user_kudosu( async def get_user_info_ruleset( session: Database, background_task: BackgroundTasks, - user_id: str = Path(description="用户 ID 或用户名"), - ruleset: GameMode | None = Path(description="指定 ruleset"), + user_id: Annotated[str, Path(description="用户 ID 或用户名")], + ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")], # current_user: User = Security(get_current_user, scopes=["public"]), ): redis = get_redis() @@ -225,7 +228,7 @@ async def get_user_info( background_task: BackgroundTasks, session: Database, request: Request, - user_id: str = Path(description="用户 ID 或用户名"), + user_id: Annotated[str, Path(description="用户 ID 或用户名")], # current_user: User = Security(get_current_user, scopes=["public"]), ): redis = get_redis() @@ -274,11 +277,11 @@ async def get_user_info( async def get_user_beatmapsets( session: Database, background_task: BackgroundTasks, - user_id: int = Path(description="用户 ID"), - type: BeatmapsetType = Path(description="谱面集类型"), - current_user: User = Security(get_current_user, scopes=["public"]), - limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), - offset: int = Query(0, ge=0, description="偏移量"), + user_id: Annotated[int, Path(description="用户 ID")], + type: Annotated[BeatmapsetType, Path(description="谱面集类型")], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100, + offset: Annotated[int, Query(ge=0, description="偏移量")] = 0, ): redis = get_redis() cache_service = get_user_cache_service(redis) @@ -356,16 +359,17 @@ async def get_user_scores( session: Database, api_version: APIVersion, background_task: BackgroundTasks, - user_id: int = Path(description="用户 ID"), - type: Literal["best", "recent", "firsts", "pinned"] = Path( - description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩") - ), - legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"), - include_fails: bool = Query(False, description="是否包含失败的成绩"), - mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"), - limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), - offset: int = Query(0, ge=0, description="偏移量"), - current_user: User = Security(get_current_user, scopes=["public"]), + user_id: Annotated[int, Path(description="用户 ID")], + type: Annotated[ + Literal["best", "recent", "firsts", "pinned"], + Path(description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")), + ], + current_user: Annotated[User, Security(get_current_user, scopes=["public"])], + legacy_only: Annotated[bool, Query(description="是否只查询 Stable 成绩")] = False, + include_fails: Annotated[bool, Query(description="是否包含失败的成绩")] = False, + mode: Annotated[GameMode | None, Query(description="指定 ruleset (可选,默认为用户主模式)")] = None, + limit: Annotated[int, Query(ge=1, le=1000, description="返回条数 (1-1000)")] = 100, + offset: Annotated[int, Query(ge=0, description="偏移量")] = 0, ): is_legacy_api = api_version < 20220705 redis = get_redis() diff --git a/app/signalr/router.py b/app/signalr/router.py index cf4bf97..753bea3 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -7,9 +7,8 @@ from typing import Literal import uuid from app.database import User as DBUser -from app.dependencies import get_current_user from app.dependencies.database import DBFactory, get_db_factory -from app.dependencies.user import get_current_user_and_token +from app.dependencies.user import get_current_user, get_current_user_and_token from app.log import logger from app.models.signalr import NegotiateResponse, Transport