refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -1,4 +1 @@
from __future__ import annotations
from .database import get_db as get_db
from .user import get_current_user as get_current_user

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from fastapi import Depends, Header
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int:
def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
if version is None:
return 0
if version < 1:

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
from app.service.beatmap_download_service import download_service
from typing import Annotated
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
from fastapi import Depends
def get_beatmap_download_service():
"""获取谱面下载服务实例"""
return download_service
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]

View File

@@ -1,16 +1,19 @@
"""
Beatmapset缓存服务依赖注入
"""
from __future__ import annotations
from app.dependencies.database import get_redis
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
from typing import Annotated
from app.dependencies.database import Redis
from app.service.beatmapset_cache_service import (
BeatmapsetCacheService as OriginBeatmapsetCacheService,
get_beatmapset_cache_service,
)
from fastapi import Depends
from redis.asyncio import Redis
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService:
def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
"""获取beatmapset缓存服务依赖"""
return get_beatmapset_cache_service(redis)
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]

View File

@@ -91,6 +91,9 @@ def get_redis():
return redis_client
Redis = Annotated[redis.Redis, Depends(get_redis)]
def get_redis_binary():
"""获取二进制数据专用的 Redis 客户端 (不自动解码响应)"""
return redis_binary_client

View File

@@ -1,17 +1,21 @@
from __future__ import annotations
from typing import Annotated
from app.config import settings
from app.dependencies.database import get_redis
from app.fetcher import Fetcher
from app.fetcher import Fetcher as OriginFetcher
from app.log import logger
fetcher: Fetcher | None = None
from fastapi import Depends
fetcher: OriginFetcher | None = None
async def get_fetcher() -> Fetcher:
async def get_fetcher() -> OriginFetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
fetcher = OriginFetcher(
settings.fetcher_client_id,
settings.fetcher_client_secret,
settings.fetcher_scopes,
@@ -27,3 +31,6 @@ async def get_fetcher() -> Fetcher:
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
return fetcher
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]

View File

@@ -6,10 +6,13 @@ from __future__ import annotations
from functools import lru_cache
import ipaddress
from typing import Annotated
from app.config import settings
from app.helpers.geoip_helper import GeoIPHelper
from fastapi import Depends, Request
@lru_cache
def get_geoip_helper() -> GeoIPHelper:
@@ -26,7 +29,7 @@ def get_geoip_helper() -> GeoIPHelper:
)
def get_client_ip(request) -> str:
def get_client_ip(request: Request) -> str:
"""
获取客户端真实 IP 地址
支持 IPv4 和 IPv6考虑代理、负载均衡器等情况
@@ -66,6 +69,10 @@ def get_client_ip(request) -> str:
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
IPAddress = Annotated[str, Depends(get_client_ip)]
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
def is_valid_ip(ip_str: str) -> bool:
"""
验证 IP 地址是否有效(支持 IPv4 和 IPv6

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import cast
from typing import Annotated, cast
from app.config import (
AWSS3StorageSettings,
@@ -9,11 +9,13 @@ from app.config import (
StorageServiceType,
settings,
)
from app.storage import StorageService
from app.storage import StorageService as OriginStorageService
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
from app.storage.local import LocalStorageService
storage: StorageService | None = None
from fastapi import Depends
storage: OriginStorageService | None = None
def init_storage_service():
@@ -50,3 +52,6 @@ def get_storage_service():
if storage is None:
return init_storage_service()
return storage
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from app.auth import get_token_by_access_token
from app.config import settings
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database import User
from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer
@@ -11,7 +12,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
from .api_version import APIVersion
from .database import Database, get_redis
from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Security
from fastapi.security import (
APIKeyQuery,
HTTPBearer,
@@ -112,13 +113,13 @@ async def get_client_user(
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
# 获取当前验证方式
verify_method = None
if api_version >= 20250913:
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= 20240101:
if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
verify_method = "totp"
else:
verify_method = "mail"
@@ -169,3 +170,6 @@ async def get_current_user(
user_and_token: UserAndToken = Depends(get_current_user_and_token),
) -> User:
return user_and_token[0]
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]