add geoip
This commit is contained in:
@@ -20,6 +20,8 @@ from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
from app.dependencies.database import get_redis
|
||||
from app.dependencies.geoip import get_geoip_helper, get_client_ip
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
@@ -29,7 +31,7 @@ from app.models.oauth import (
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
@@ -79,18 +81,20 @@ def validate_password(password: str) -> list[str]:
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
name="注册用户",
|
||||
description="用户注册接口",
|
||||
)
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_username: str = Form(..., alias="user[username]", description="用户名"),
|
||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper)
|
||||
):
|
||||
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
password_errors = validate_password(user_password)
|
||||
@@ -119,6 +123,21 @@ async def register_user(
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
client_ip = get_client_ip(request)
|
||||
country_code = "CN" # 默认国家代码
|
||||
|
||||
try:
|
||||
# 查询 IP 地理位置
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
@@ -137,7 +156,7 @@ async def register_user(
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country_code="CN", # 默认国家
|
||||
country_code=country_code, # 根据 IP 地理位置设置国家
|
||||
join_date=datetime.now(UTC),
|
||||
last_visit=datetime.now(UTC),
|
||||
is_supporter=settings.enable_supporter_for_all_users,
|
||||
|
||||
Reference in New Issue
Block a user