chore(lint): make ruff happy
This commit is contained in:
@@ -20,10 +20,9 @@ from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
from app.dependencies.database import get_redis
|
||||
from app.dependencies.geoip import get_geoip_helper, get_client_ip
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
@@ -31,6 +30,7 @@ from app.models.oauth import (
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
from app.service.login_log_service import LoginLogService
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -82,6 +82,7 @@ def validate_password(password: str) -> list[str]:
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
name="注册用户",
|
||||
@@ -93,9 +94,8 @@ async def register_user(
|
||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper)
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
password_errors = validate_password(user_password)
|
||||
@@ -127,18 +127,21 @@ async def register_user(
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
client_ip = get_client_ip(request)
|
||||
country_code = "CN" # 默认国家代码
|
||||
|
||||
|
||||
try:
|
||||
# 查询 IP 地理位置
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
logger.info(
|
||||
f"User {user_username} registering from "
|
||||
f"{client_ip}, country: {country_code}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
|
||||
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
@@ -276,9 +279,9 @@ async def oauth_token(
|
||||
request=request,
|
||||
attempted_username=username,
|
||||
login_method="password",
|
||||
notes="Invalid credentials"
|
||||
notes="Invalid credentials",
|
||||
)
|
||||
|
||||
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description=(
|
||||
@@ -293,9 +296,9 @@ async def oauth_token(
|
||||
|
||||
# 确保用户对象与当前会话关联
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
# 记录成功的登录
|
||||
user_id = getattr(user, 'id')
|
||||
user_id = getattr(user, "id")
|
||||
assert user_id is not None, "User ID should not be None after authentication"
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
@@ -303,7 +306,7 @@ async def oauth_token(
|
||||
request=request,
|
||||
login_success=True,
|
||||
login_method="password",
|
||||
notes=f"OAuth password grant for client {client_id}"
|
||||
notes=f"OAuth password grant for client {client_id}",
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
@@ -424,16 +427,16 @@ async def oauth_token(
|
||||
hint="Invalid authorization code",
|
||||
)
|
||||
user, scopes = code_result
|
||||
|
||||
|
||||
# 确保用户对象与当前会话关联
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 重新查询只获取ID,避免触发延迟加载
|
||||
id_result = await db.exec(select(User.id).where(User.username == username))
|
||||
user_id = id_result.first()
|
||||
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user