refactor(api): use Annotated-style dependency injection
This commit is contained in:
@@ -1,4 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .database import get_db as get_db
|
||||
from .user import get_current_user as get_current_user
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Annotated
|
||||
from fastapi import Depends, Header
|
||||
|
||||
|
||||
def get_api_version(version: int | None = Header(None, alias="x-api-version")) -> int:
|
||||
def get_api_version(version: int | None = Header(None, alias="x-api-version", include_in_schema=False)) -> int:
|
||||
if version is None:
|
||||
return 0
|
||||
if version < 1:
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.service.beatmap_download_service import download_service
|
||||
from typing import Annotated
|
||||
|
||||
from app.service.beatmap_download_service import BeatmapDownloadService, download_service
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
|
||||
def get_beatmap_download_service():
|
||||
"""获取谱面下载服务实例"""
|
||||
return download_service
|
||||
|
||||
|
||||
DownloadService = Annotated[BeatmapDownloadService, Depends(get_beatmap_download_service)]
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
"""
|
||||
Beatmapset缓存服务依赖注入
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.beatmapset_cache_service import BeatmapsetCacheService, get_beatmapset_cache_service
|
||||
from typing import Annotated
|
||||
|
||||
from app.dependencies.database import Redis
|
||||
from app.service.beatmapset_cache_service import (
|
||||
BeatmapsetCacheService as OriginBeatmapsetCacheService,
|
||||
get_beatmapset_cache_service,
|
||||
)
|
||||
|
||||
from fastapi import Depends
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
def get_beatmapset_cache_dependency(redis: Redis = Depends(get_redis)) -> BeatmapsetCacheService:
|
||||
def get_beatmapset_cache_dependency(redis: Redis) -> OriginBeatmapsetCacheService:
|
||||
"""获取beatmapset缓存服务依赖"""
|
||||
return get_beatmapset_cache_service(redis)
|
||||
|
||||
|
||||
BeatmapsetCacheService = Annotated[OriginBeatmapsetCacheService, Depends(get_beatmapset_cache_dependency)]
|
||||
|
||||
@@ -91,6 +91,9 @@ def get_redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
Redis = Annotated[redis.Redis, Depends(get_redis)]
|
||||
|
||||
|
||||
def get_redis_binary():
|
||||
"""获取二进制数据专用的 Redis 客户端 (不自动解码响应)"""
|
||||
return redis_binary_client
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import get_redis
|
||||
from app.fetcher import Fetcher
|
||||
from app.fetcher import Fetcher as OriginFetcher
|
||||
from app.log import logger
|
||||
|
||||
fetcher: Fetcher | None = None
|
||||
from fastapi import Depends
|
||||
|
||||
fetcher: OriginFetcher | None = None
|
||||
|
||||
|
||||
async def get_fetcher() -> Fetcher:
|
||||
async def get_fetcher() -> OriginFetcher:
|
||||
global fetcher
|
||||
if fetcher is None:
|
||||
fetcher = Fetcher(
|
||||
fetcher = OriginFetcher(
|
||||
settings.fetcher_client_id,
|
||||
settings.fetcher_client_secret,
|
||||
settings.fetcher_scopes,
|
||||
@@ -27,3 +31,6 @@ async def get_fetcher() -> Fetcher:
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
|
||||
return fetcher
|
||||
|
||||
|
||||
Fetcher = Annotated[OriginFetcher, Depends(get_fetcher)]
|
||||
|
||||
@@ -6,10 +6,13 @@ from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
import ipaddress
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_geoip_helper() -> GeoIPHelper:
|
||||
@@ -26,7 +29,7 @@ def get_geoip_helper() -> GeoIPHelper:
|
||||
)
|
||||
|
||||
|
||||
def get_client_ip(request) -> str:
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实 IP 地址
|
||||
支持 IPv4 和 IPv6,考虑代理、负载均衡器等情况
|
||||
@@ -66,6 +69,10 @@ def get_client_ip(request) -> str:
|
||||
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
|
||||
|
||||
|
||||
IPAddress = Annotated[str, Depends(get_client_ip)]
|
||||
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
|
||||
|
||||
|
||||
def is_valid_ip(ip_str: str) -> bool:
|
||||
"""
|
||||
验证 IP 地址是否有效(支持 IPv4 和 IPv6)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import Annotated, cast
|
||||
|
||||
from app.config import (
|
||||
AWSS3StorageSettings,
|
||||
@@ -9,11 +9,13 @@ from app.config import (
|
||||
StorageServiceType,
|
||||
settings,
|
||||
)
|
||||
from app.storage import StorageService
|
||||
from app.storage import StorageService as OriginStorageService
|
||||
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
storage: StorageService | None = None
|
||||
from fastapi import Depends
|
||||
|
||||
storage: OriginStorageService | None = None
|
||||
|
||||
|
||||
def init_storage_service():
|
||||
@@ -50,3 +52,6 @@ def get_storage_service():
|
||||
if storage is None:
|
||||
return init_storage_service()
|
||||
return storage
|
||||
|
||||
|
||||
StorageService = Annotated[OriginStorageService, Depends(get_storage_service)]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Annotated
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.const import SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database import User
|
||||
from app.database.auth import OAuthToken, V1APIKeys
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
@@ -11,7 +12,7 @@ from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
from .api_version import APIVersion
|
||||
from .database import Database, get_redis
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from fastapi.security import (
|
||||
APIKeyQuery,
|
||||
HTTPBearer,
|
||||
@@ -112,13 +113,13 @@ async def get_client_user(
|
||||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||||
# 获取当前验证方式
|
||||
verify_method = None
|
||||
if api_version >= 20250913:
|
||||
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
|
||||
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
|
||||
|
||||
if verify_method is None:
|
||||
# 智能选择验证方式(有TOTP优先TOTP)
|
||||
totp_key = await user.awaitable_attrs.totp_key
|
||||
if totp_key is not None and api_version >= 20240101:
|
||||
if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
|
||||
verify_method = "totp"
|
||||
else:
|
||||
verify_method = "mail"
|
||||
@@ -169,3 +170,6 @@ async def get_current_user(
|
||||
user_and_token: UserAndToken = Depends(get_current_user_and_token),
|
||||
) -> User:
|
||||
return user_and_token[0]
|
||||
|
||||
|
||||
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]
|
||||
|
||||
Reference in New Issue
Block a user