chore(lint): make ruff happy

This commit is contained in:
MingxuanGame
2025-08-17 16:57:27 +00:00
parent 3c460f1d82
commit 86bea5d4b5
13 changed files with 316 additions and 181 deletions

View File

@@ -20,10 +20,9 @@ from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.dependencies.geoip import get_geoip_helper, get_client_ip
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.service.login_log_service import LoginLogService
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
@@ -31,6 +30,7 @@ from app.models.oauth import (
UserRegistrationErrors,
)
from app.models.score import GameMode
from app.service.login_log_service import LoginLogService
from fastapi import APIRouter, Depends, Form, Request
from fastapi.responses import JSONResponse
@@ -82,6 +82,7 @@ def validate_password(password: str) -> list[str]:
router = APIRouter(tags=["osu! OAuth 认证"])
@router.post(
"/users",
name="注册用户",
@@ -93,9 +94,8 @@ async def register_user(
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"),
db: AsyncSession = Depends(get_db),
geoip: GeoIPHelper = Depends(get_geoip_helper)
geoip: GeoIPHelper = Depends(get_geoip_helper),
):
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
@@ -127,18 +127,21 @@ async def register_user(
# 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码
try:
# 查询 IP 地理位置
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
logger.info(
f"User {user_username} registering from "
f"{client_ip}, country: {country_code}"
)
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
@@ -276,9 +279,9 @@ async def oauth_token(
request=request,
attempted_username=username,
login_method="password",
notes="Invalid credentials"
notes="Invalid credentials",
)
return create_oauth_error_response(
error="invalid_grant",
description=(
@@ -293,9 +296,9 @@ async def oauth_token(
# 确保用户对象与当前会话关联
await db.refresh(user)
# 记录成功的登录
user_id = getattr(user, 'id')
user_id = getattr(user, "id")
assert user_id is not None, "User ID should not be None after authentication"
await LoginLogService.record_login(
db=db,
@@ -303,7 +306,7 @@ async def oauth_token(
request=request,
login_success=True,
login_method="password",
notes=f"OAuth password grant for client {client_id}"
notes=f"OAuth password grant for client {client_id}",
)
# 生成令牌
@@ -424,16 +427,16 @@ async def oauth_token(
hint="Invalid authorization code",
)
user, scopes = code_result
# 确保用户对象与当前会话关联
await db.refresh(user)
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 重新查询只获取ID避免触发延迟加载
id_result = await db.exec(select(User.id).where(User.username == username))
user_id = id_result.first()
access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires
)