refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import timedelta
import re
from typing import Literal
from typing import Annotated, Literal
from app.auth import (
authenticate_user,
@@ -19,10 +19,9 @@ from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.auth import TotpKeys
from app.database.statistics import UserStatistics
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.database import Database, Redis
from app.dependencies.geoip import GeoIPService, IPAddress
from app.dependencies.user_agent import UserAgentInfo
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import (
@@ -40,9 +39,8 @@ from app.service.verification_service import (
)
from app.utils import utcnow
from fastapi import APIRouter, Depends, Form, Header, Request
from fastapi import APIRouter, Form, Header, Request
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import exists, select
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
)
async def register_user(
db: Database,
request: Request,
user_username: str = Form(..., alias="user[username]", description="用户名"),
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"),
geoip: GeoIPHelper = Depends(get_geoip_helper),
user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")],
user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")],
user_password: Annotated[str, Form(..., alias="user[password]", description="密码")],
geoip: GeoIPService,
client_ip: IPAddress,
):
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
@@ -126,7 +124,6 @@ async def register_user(
try:
# 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码
try:
@@ -201,19 +198,21 @@ async def oauth_token(
db: Database,
request: Request,
user_agent: UserAgentInfo,
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
),
client_id: int = Form(..., description="客户端 ID"),
client_secret: str = Form(..., description="客户端密钥"),
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"),
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
password: str | None = Form(None, description="密码(仅密码模式需要)"),
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
ip_address: IPAddress,
grant_type: Annotated[
Literal["authorization_code", "refresh_token", "password", "client_credentials"],
Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"),
],
client_id: Annotated[int, Form(..., description="客户端 ID")],
client_secret: Annotated[str, Form(..., description="客户端密钥")],
redis: Redis,
geoip: GeoIPService,
code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None,
scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*'")] = "*",
username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None,
password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None,
refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None,
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
):
scopes = scope.split(" ")
@@ -311,8 +310,6 @@ async def oauth_token(
)
token_id = token.id
ip_address = get_client_ip(request)
# 获取国家代码
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
@@ -571,16 +568,14 @@ async def oauth_token(
)
async def request_password_reset(
request: Request,
email: str = Form(..., description="邮箱地址"),
redis: Redis = Depends(get_redis),
email: Annotated[str, Form(..., description="邮箱地址")],
redis: Redis,
ip_address: IPAddress,
):
"""
请求密码重置
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 请求密码重置
@@ -599,20 +594,16 @@ async def request_password_reset(
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
async def reset_password(
request: Request,
email: str = Form(..., description="邮箱地址"),
reset_code: str = Form(..., description="重置验证"),
new_password: str = Form(..., description="新密码"),
redis: Redis = Depends(get_redis),
email: Annotated[str, Form(..., description="邮箱地址")],
reset_code: Annotated[str, Form(..., description="重置验证码")],
new_password: Annotated[str, Form(..., description="新密")],
redis: Redis,
ip_address: IPAddress,
):
"""
重置密码
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
# 重置密码
success, message = await password_reset_service.reset_password(
email=email.lower().strip(),