refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -1,15 +1,16 @@
from __future__ import annotations
from typing import Annotated
from app.database.auth import OAuthToken
from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
from app.dependencies.database import Database
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.geoip import GeoIPService
from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.helpers.geoip_helper import GeoIPHelper
from .router import router
from fastapi import Depends, HTTPException, Security
from fastapi import HTTPException, Security
from pydantic import BaseModel
from sqlmodel import col, select
@@ -28,8 +29,8 @@ class SessionsResp(BaseModel):
)
async def get_sessions(
session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
geoip: GeoIPService,
):
current_user, token = user_and_token
sessions = (
@@ -57,7 +58,7 @@ async def get_sessions(
async def delete_session(
session: Database,
session_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
):
current_user, token = user_and_token
if session_id == token.id:
@@ -91,8 +92,8 @@ class TrustedDevicesResp(BaseModel):
)
async def get_trusted_devices(
session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
geoip: GeoIPService,
):
current_user, token = user_and_token
devices = (
@@ -131,7 +132,7 @@ async def get_trusted_devices(
async def delete_trusted_device(
session: Database,
device_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token),
user_and_token: Annotated[UserAndToken, Security(get_client_user_and_token)],
):
current_user, token = user_and_token
device = await session.get(TrustedDevice, device_id)

View File

@@ -1,25 +1,24 @@
from __future__ import annotations
import hashlib
from typing import Annotated
from app.database.user import User
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.storage.base import StorageService
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.utils import check_image
from .router import router
from fastapi import Depends, File, Security
from fastapi import File
@router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"])
async def upload_avatar(
session: Database,
content: bytes = File(...),
current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service),
content: Annotated[bytes, File(...)],
current_user: ClientUser,
storage: StorageService,
):
"""上传用户头像

View File

@@ -1,17 +1,18 @@
from __future__ import annotations
from typing import Annotated
from app.database.beatmap import Beatmap
from app.database.beatmapset import Beatmapset
from app.database.beatmapset_ratings import BeatmapRating
from app.database.score import Score
from app.database.user import User
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.dependencies.user import ClientUser
from app.service.beatmapset_update_service import get_beatmapset_update_service
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from fastapi import Body, Depends, HTTPException
from fastapi_limiter.depends import RateLimiter
from sqlmodel import col, exists, select
@@ -25,7 +26,7 @@ from sqlmodel import col, exists, select
async def can_rate_beatmapset(
beatmapset_id: int,
session: Database,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
"""检查用户是否可以评价谱面集
@@ -57,8 +58,8 @@ async def can_rate_beatmapset(
async def rate_beatmaps(
beatmapset_id: int,
session: Database,
rating: int = Body(..., ge=0, le=10),
current_user: User = Security(get_client_user),
rating: Annotated[int, Body(..., ge=0, le=10)],
current_user: ClientUser,
):
"""为谱面集评分
@@ -96,7 +97,7 @@ async def rate_beatmaps(
async def sync_beatmapset(
beatmapset_id: int,
session: Database,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
"""请求同步谱面集

View File

@@ -1,25 +1,25 @@
from __future__ import annotations
import hashlib
from typing import Annotated
from app.database.user import User, UserProfileCover
from app.database.user import UserProfileCover
from app.dependencies.database import Database
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.storage.base import StorageService
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.utils import check_image
from .router import router
from fastapi import Depends, File, Security
from fastapi import File
@router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"])
async def upload_cover(
session: Database,
content: bytes = File(...),
current_user: User = Security(get_client_user),
storage: StorageService = Depends(get_storage_service),
content: Annotated[bytes, File(...)],
current_user: ClientUser,
storage: StorageService,
):
"""上传用户头图

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
import secrets
from typing import Annotated
from app.database.auth import OAuthClient, OAuthToken
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from redis.asyncio import Redis
from fastapi import Body, HTTPException
from sqlmodel import select, text
@@ -22,10 +21,10 @@ from sqlmodel import select, text
)
async def create_oauth_app(
session: Database,
name: str = Body(..., max_length=100, description="应用程序名称"),
description: str = Body("", description="应用程序描述"),
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user),
name: Annotated[str, Body(..., max_length=100, description="应用程序名称")],
redirect_uris: Annotated[list[str], Body(..., description="允许的重定向 URI 列表")],
current_user: ClientUser,
description: Annotated[str, Body(description="应用程序描述")] = "",
):
result = await session.execute(
text(
@@ -64,7 +63,7 @@ async def create_oauth_app(
async def get_oauth_app(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_app = await session.get(OAuthClient, client_id)
if not oauth_app:
@@ -85,7 +84,7 @@ async def get_oauth_app(
)
async def get_user_oauth_apps(
session: Database,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
return [
@@ -109,7 +108,7 @@ async def get_user_oauth_apps(
async def delete_oauth_app(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client:
@@ -134,10 +133,10 @@ async def delete_oauth_app(
async def update_oauth_app(
session: Database,
client_id: int,
name: str = Body(..., max_length=100, description="应用程序新名称"),
description: str = Body("", description="应用程序新描述"),
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
current_user: User = Security(get_client_user),
name: Annotated[str, Body(..., max_length=100, description="应用程序新名称")],
redirect_uris: Annotated[list[str], Body(..., description="新的重定向 URI 列表")],
current_user: ClientUser,
description: Annotated[str, Body(description="应用程序新描述")] = "",
):
oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client:
@@ -168,7 +167,7 @@ async def update_oauth_app(
async def refresh_secret(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client:
@@ -200,10 +199,10 @@ async def refresh_secret(
async def generate_oauth_code(
session: Database,
client_id: int,
current_user: User = Security(get_client_user),
redirect_uri: str = Body(..., description="授权后重定向的 URI"),
scopes: list[str] = Body(..., description="请求的权限范围列表"),
redis: Redis = Depends(get_redis),
current_user: ClientUser,
redirect_uri: Annotated[str, Body(..., description="授权后重定向的 URI")],
scopes: Annotated[list[str], Body(..., description="请求的权限范围列表")],
redis: Redis,
):
client = await session.get(OAuthClient, client_id)
if not client:

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
from app.database import Relationship, User
from typing import Annotated
from app.database import Relationship
from app.database.relationship import RelationshipType
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.dependencies.user import ClientUser
from .router import router
from fastapi import HTTPException, Path, Security
from fastapi import HTTPException, Path
from pydantic import BaseModel, Field
from sqlmodel import select
@@ -27,8 +29,8 @@ class CheckResponse(BaseModel):
)
async def check_user_relationship(
db: Database,
user_id: int = Path(..., description="目标用户的 ID"),
current_user: User = Security(get_client_user),
user_id: Annotated[int, Path(..., description="目标用户的 ID")],
current_user: ClientUser,
):
if user_id == current_user.id:
raise HTTPException(422, "Cannot check relationship with yourself")

View File

@@ -1,17 +1,14 @@
from __future__ import annotations
from app.database.score import Score
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.service.user_cache_service import refresh_user_cache_background
from app.storage.base import StorageService
from .router import router
from fastapi import BackgroundTasks, Depends, HTTPException, Security
from redis.asyncio import Redis
from fastapi import BackgroundTasks, HTTPException
@router.delete(
@@ -24,9 +21,9 @@ async def delete_score(
session: Database,
background_task: BackgroundTasks,
score_id: int,
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
storage_service: StorageService = Depends(get_storage_service),
redis: Redis,
current_user: ClientUser,
storage_service: StorageService,
):
"""删除成绩

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
import hashlib
from typing import Annotated
from app.database.team import Team, TeamMember, TeamRequest
from app.database.user import BASE_INCLUDES, User, UserResp
from app.dependencies.database import Database, get_redis
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
from app.models.notification import (
TeamApplicationAccept,
TeamApplicationReject,
@@ -14,27 +15,25 @@ from app.models.notification import (
)
from app.router.notification import server
from app.service.ranking_cache_service import get_ranking_cache_service
from app.storage.base import StorageService
from app.utils import check_image, utcnow
from .router import router
from fastapi import Depends, File, Form, HTTPException, Path, Request, Security
from fastapi import File, Form, HTTPException, Path, Request
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import exists, select
@router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"])
async def create_team(
session: Database,
storage: StorageService = Depends(get_storage_service),
current_user: User = Security(get_client_user),
flag: bytes = File(..., description="战队图标文件"),
cover: bytes = File(..., description="战队头图文件"),
name: str = Form(max_length=100, description="战队名称"),
short_name: str = Form(max_length=10, description="战队缩写"),
redis: Redis = Depends(get_redis),
storage: StorageService,
current_user: ClientUser,
flag: Annotated[bytes, File(..., description="战队图标文件")],
cover: Annotated[bytes, File(..., description="战队头图文件")],
name: Annotated[str, Form(max_length=100, description="战队名称")],
short_name: Annotated[str, Form(max_length=10, description="战队缩写")],
redis: Redis,
):
"""创建战队。
@@ -88,13 +87,13 @@ async def create_team(
async def update_team(
team_id: int,
session: Database,
storage: StorageService = Depends(get_storage_service),
current_user: User = Security(get_client_user),
flag: bytes | None = File(default=None, description="战队图标文件"),
cover: bytes | None = File(default=None, description="战队头图文件"),
name: str | None = Form(default=None, max_length=100, description="战队名称"),
short_name: str | None = Form(default=None, max_length=10, description="战队缩写"),
leader_id: int | None = Form(default=None, description="战队队长 ID"),
storage: StorageService,
current_user: ClientUser,
flag: Annotated[bytes | None, File(description="战队图标文件")] = None,
cover: Annotated[bytes | None, File(description="战队头图文件")] = None,
name: Annotated[str | None, Form(max_length=100, description="战队名称")] = None,
short_name: Annotated[str | None, Form(max_length=10, description="战队缩写")] = None,
leader_id: Annotated[int | None, Form(description="战队队长 ID")] = None,
):
"""修改战队。
@@ -161,9 +160,9 @@ async def update_team(
@router.delete("/team/{team_id}", name="删除战队", status_code=204, tags=["战队", "g0v0 API"])
async def delete_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: ClientUser,
redis: Redis,
):
team = await session.get(Team, team_id)
if not team:
@@ -191,7 +190,7 @@ class TeamQueryResp(BaseModel):
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
async def get_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
team_id: Annotated[int, Path(..., description="战队 ID")],
):
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
return TeamQueryResp(
@@ -203,8 +202,8 @@ async def get_team(
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])
async def request_join_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
current_user: User = Security(get_client_user),
team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: ClientUser,
):
team = await session.get(Team, team_id)
if not team:
@@ -231,10 +230,10 @@ async def request_join_team(
async def handle_request(
req: Request,
session: Database,
team_id: int = Path(..., description="战队 ID"),
user_id: int = Path(..., description="用户 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
team_id: Annotated[int, Path(..., description="战队 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: ClientUser,
redis: Redis,
):
team = await session.get(Team, team_id)
if not team:
@@ -272,10 +271,10 @@ async def handle_request(
@router.delete("/team/{team_id}/{user_id}", name="踢出成员 / 退出战队", status_code=204, tags=["战队", "g0v0 API"])
async def kick_member(
session: Database,
team_id: int = Path(..., description="战队 ID"),
user_id: int = Path(..., description="用户 ID"),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
team_id: Annotated[int, Path(..., description="战队 ID")],
user_id: Annotated[int, Path(..., description="用户 ID")],
current_user: ClientUser,
redis: Redis,
):
team = await session.get(Team, team_id)
if not team:

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Annotated
from app.auth import (
check_totp_backup_code,
finish_create_totp_key,
@@ -9,17 +11,15 @@ from app.auth import (
)
from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from .router import router
from fastapi import Body, Depends, HTTPException, Security
from fastapi import Body, HTTPException
from pydantic import BaseModel
import pyotp
from redis.asyncio import Redis
class TotpStatusResp(BaseModel):
@@ -37,7 +37,7 @@ class TotpStatusResp(BaseModel):
response_model=TotpStatusResp,
)
async def get_totp_status(
current_user: User = Security(get_client_user),
current_user: ClientUser,
):
"""检查用户是否已创建TOTP"""
totp_key = await current_user.awaitable_attrs.totp_key
@@ -62,8 +62,8 @@ async def get_totp_status(
status_code=201,
)
async def start_create_totp(
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
redis: Redis,
current_user: ClientUser,
):
if await current_user.awaitable_attrs.totp_key:
raise HTTPException(status_code=400, detail="TOTP is already enabled for this user")
@@ -98,9 +98,9 @@ async def start_create_totp(
)
async def finish_create_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码")],
redis: Redis,
current_user: ClientUser,
):
status, backup_codes = await finish_create_totp_key(current_user, code, redis, session)
if status == FinishStatus.SUCCESS:
@@ -122,9 +122,9 @@ async def finish_create_totp(
)
async def disable_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
code: Annotated[str, Body(..., embed=True, description="用户提供的 TOTP 代码或备份码")],
redis: Redis,
current_user: ClientUser,
):
totp = await session.get(TotpKeys, current_user.id)
if not totp:

View File

@@ -1,24 +1,26 @@
from __future__ import annotations
from typing import Annotated
from app.auth import validate_username
from app.config import settings
from app.database.events import Event, EventType
from app.database.user import User
from app.dependencies.database import Database
from app.dependencies.user import get_client_user
from app.dependencies.user import ClientUser
from app.utils import utcnow
from .router import router
from fastapi import Body, HTTPException, Security
from fastapi import Body, HTTPException
from sqlmodel import exists, select
@router.post("/rename", name="修改用户名", tags=["用户", "g0v0 API"])
async def user_rename(
session: Database,
new_name: str = Body(..., description="新的用户名"),
current_user: User = Security(get_client_user),
new_name: Annotated[str, Body(..., description="新的用户名")],
current_user: ClientUser,
):
"""修改用户名