refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.score import GameMode
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
@@ -44,13 +44,9 @@ from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
def create_oauth_error_response(
|
||||
error: str, description: str, hint: str, status_code: int = 400
|
||||
):
|
||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
||||
"""创建标准的 OAuth 错误响应"""
|
||||
error_data = OAuthErrorResponse(
|
||||
error=error, error_description=description, hint=hint, message=description
|
||||
)
|
||||
error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description)
|
||||
return JSONResponse(status_code=status_code, content=error_data.model_dump())
|
||||
|
||||
|
||||
@@ -123,9 +119,7 @@ async def register_user(
|
||||
)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
return JSONResponse(status_code=422, content={"form_error": errors.model_dump()})
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
@@ -137,10 +131,7 @@ async def register_user(
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
logger.info(
|
||||
f"User {user_username} registering from "
|
||||
f"{client_ip}, country: {country_code}"
|
||||
)
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
@@ -148,7 +139,7 @@ async def register_user(
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
|
||||
@@ -173,7 +164,6 @@ async def register_user(
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
assert new_user.id is not None, "New user ID should not be None"
|
||||
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
|
||||
statistics = UserStatistics(mode=i, user_id=new_user.id)
|
||||
db.add(statistics)
|
||||
@@ -193,36 +183,30 @@ async def register_user(
|
||||
logger.exception(f"Registration error for user {user_username}")
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
message="An error occurred while creating your account. Please try again."
|
||||
)
|
||||
errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"form_error": errors.model_dump()})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/token",
|
||||
response_model=Union[TokenResponse, ExtendedTokenResponse],
|
||||
response_model=TokenResponse | ExtendedTokenResponse,
|
||||
name="获取访问令牌",
|
||||
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
||||
)
|
||||
async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
grant_type: Literal[
|
||||
"authorization_code", "refresh_token", "password", "client_credentials"
|
||||
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"),
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
client_id: int = Form(..., description="客户端 ID"),
|
||||
client_secret: str = Form(..., description="客户端密钥"),
|
||||
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
|
||||
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"),
|
||||
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
|
||||
password: str | None = Form(None, description="密码(仅密码模式需要)"),
|
||||
refresh_token: str | None = Form(
|
||||
None, description="刷新令牌(仅刷新令牌模式需要)"
|
||||
),
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
@@ -303,37 +287,33 @@ async def oauth_token(
|
||||
await db.refresh(user)
|
||||
|
||||
# 获取用户信息和客户端信息
|
||||
user_id = getattr(user, "id")
|
||||
assert user_id is not None, "User ID should not be None after authentication"
|
||||
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
user_id = user.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
|
||||
# 检查是否为新位置登录
|
||||
is_new_location = await LoginSessionService.check_new_location(
|
||||
db, user_id, ip_address, country_code
|
||||
)
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
# 创建登录会话记录
|
||||
login_session = await LoginSessionService.create_session(
|
||||
login_session = await LoginSessionService.create_session( # noqa: F841
|
||||
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
|
||||
)
|
||||
|
||||
|
||||
# 如果是新位置登录,需要邮件验证
|
||||
if is_new_location and settings.enable_email_verification:
|
||||
# 刷新用户对象以确保属性已加载
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
# 发送邮件验证码
|
||||
verification_sent = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, user.username, user.email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
# 记录需要二次验证的登录尝试
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
@@ -343,14 +323,16 @@ async def oauth_token(
|
||||
login_method="password_pending_verification",
|
||||
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
|
||||
if not verification_sent:
|
||||
# 邮件发送失败,记录错误
|
||||
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
|
||||
elif is_new_location and not settings.enable_email_verification:
|
||||
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
|
||||
logger.debug(
|
||||
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
|
||||
)
|
||||
else:
|
||||
# 不是新位置登录,正常登录
|
||||
await LoginLogService.record_login(
|
||||
@@ -361,20 +343,17 @@ async def oauth_token(
|
||||
login_method="password",
|
||||
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
|
||||
# 无论是否新位置登录,都返回正常的token
|
||||
# session_verified状态通过/me接口的session_verified字段来体现
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 获取用户ID,避免触发延迟加载
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user_id
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
@@ -423,9 +402,7 @@ async def oauth_token(
|
||||
|
||||
# 生成新的访问令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires)
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# 更新令牌
|
||||
@@ -489,17 +466,11 @@ async def oauth_token(
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 重新查询只获取ID,避免触发延迟加载
|
||||
id_result = await db.exec(select(User.id).where(User.username == username))
|
||||
user_id = id_result.first()
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
user_id = user.id
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user_id
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
@@ -539,9 +510,7 @@ async def oauth_token(
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": "3"}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
@@ -567,7 +536,7 @@ async def oauth_token(
|
||||
@router.post(
|
||||
"/password-reset/request",
|
||||
name="请求密码重置",
|
||||
description="通过邮箱请求密码重置验证码"
|
||||
description="通过邮箱请求密码重置验证码",
|
||||
)
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
@@ -578,42 +547,26 @@ async def request_password_reset(
|
||||
请求密码重置
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
|
||||
# 请求密码重置
|
||||
success, message = await password_reset_service.request_password_reset(
|
||||
email=email.lower().strip(),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
redis=redis
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "message": message})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"success": False,
|
||||
"error": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=400, content={"success": False, "error": message})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/password-reset/reset",
|
||||
name="重置密码",
|
||||
description="使用验证码重置密码"
|
||||
)
|
||||
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
@@ -625,32 +578,20 @@ async def reset_password(
|
||||
重置密码
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
|
||||
# 重置密码
|
||||
success, message = await password_reset_service.reset_password(
|
||||
email=email.lower().strip(),
|
||||
reset_code=reset_code.strip(),
|
||||
new_password=new_password,
|
||||
ip_address=ip_address,
|
||||
redis=redis
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "message": message})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"success": False,
|
||||
"error": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=400, content={"success": False, "error": message})
|
||||
|
||||
Reference in New Issue
Block a user