refactor(api): use Annotated-style dependency injection
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import re
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -19,10 +19,9 @@ from app.const import BANCHOBOT_ID
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.geoip import GeoIPService, IPAddress
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.oauth import (
|
||||
@@ -40,9 +39,8 @@ from app.service.verification_service import (
|
||||
)
|
||||
from app.utils import utcnow
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Header, Request
|
||||
from fastapi import APIRouter, Form, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import exists, select
|
||||
|
||||
@@ -93,11 +91,11 @@ router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
)
|
||||
async def register_user(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_username: str = Form(..., alias="user[username]", description="用户名"),
|
||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
user_username: Annotated[str, Form(..., alias="user[username]", description="用户名")],
|
||||
user_email: Annotated[str, Form(..., alias="user[user_email]", description="电子邮箱")],
|
||||
user_password: Annotated[str, Form(..., alias="user[password]", description="密码")],
|
||||
geoip: GeoIPService,
|
||||
client_ip: IPAddress,
|
||||
):
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
@@ -126,7 +124,6 @@ async def register_user(
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
client_ip = get_client_ip(request)
|
||||
country_code = "CN" # 默认国家代码
|
||||
|
||||
try:
|
||||
@@ -201,19 +198,21 @@ async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_agent: UserAgentInfo,
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
client_id: int = Form(..., description="客户端 ID"),
|
||||
client_secret: str = Form(..., description="客户端密钥"),
|
||||
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
|
||||
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"),
|
||||
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
|
||||
password: str | None = Form(None, description="密码(仅密码模式需要)"),
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
ip_address: IPAddress,
|
||||
grant_type: Annotated[
|
||||
Literal["authorization_code", "refresh_token", "password", "client_credentials"],
|
||||
Form(..., description="授权类型:密码、刷新令牌和授权码三种授权方式。"),
|
||||
],
|
||||
client_id: Annotated[int, Form(..., description="客户端 ID")],
|
||||
client_secret: Annotated[str, Form(..., description="客户端密钥")],
|
||||
redis: Redis,
|
||||
geoip: GeoIPService,
|
||||
code: Annotated[str | None, Form(description="授权码(仅授权码模式需要)")] = None,
|
||||
scope: Annotated[str, Form(description="权限范围(空格分隔,默认为 '*')")] = "*",
|
||||
username: Annotated[str | None, Form(description="用户名(仅密码模式需要)")] = None,
|
||||
password: Annotated[str | None, Form(description="密码(仅密码模式需要)")] = None,
|
||||
refresh_token: Annotated[str | None, Form(description="刷新令牌(仅刷新令牌模式需要)")] = None,
|
||||
web_uuid: Annotated[str | None, Header(include_in_schema=False, alias="X-UUID")] = None,
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
@@ -311,8 +310,6 @@ async def oauth_token(
|
||||
)
|
||||
token_id = token.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
@@ -571,16 +568,14 @@ async def oauth_token(
|
||||
)
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
email: Annotated[str, Form(..., description="邮箱地址")],
|
||||
redis: Redis,
|
||||
ip_address: IPAddress,
|
||||
):
|
||||
"""
|
||||
请求密码重置
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
# 请求密码重置
|
||||
@@ -599,20 +594,16 @@ async def request_password_reset(
|
||||
|
||||
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
reset_code: str = Form(..., description="重置验证码"),
|
||||
new_password: str = Form(..., description="新密码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
email: Annotated[str, Form(..., description="邮箱地址")],
|
||||
reset_code: Annotated[str, Form(..., description="重置验证码")],
|
||||
new_password: Annotated[str, Form(..., description="新密码")],
|
||||
redis: Redis,
|
||||
ip_address: IPAddress,
|
||||
):
|
||||
"""
|
||||
重置密码
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
# 重置密码
|
||||
success, message = await password_reset_service.reset_password(
|
||||
email=email.lower().strip(),
|
||||
|
||||
Reference in New Issue
Block a user