add geoip

This commit is contained in:
咕谷酱
2025-08-17 23:56:46 +08:00
parent eaa6d4d92b
commit de0c86f4a2
11 changed files with 503 additions and 82 deletions

View File

@@ -20,6 +20,8 @@ from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.dependencies.geoip import get_geoip_helper, get_client_ip
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
@@ -29,7 +31,7 @@ from app.models.oauth import (
)
from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi import APIRouter, Depends, Form, Request
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
@@ -79,18 +81,20 @@ def validate_password(password: str) -> list[str]:
router = APIRouter(tags=["osu! OAuth 认证"])
@router.post(
"/users",
name="注册用户",
description="用户注册接口",
)
async def register_user(
request: Request,
user_username: str = Form(..., alias="user[username]", description="用户名"),
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"),
db: AsyncSession = Depends(get_db),
geoip: GeoIPHelper = Depends(get_geoip_helper)
):
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
@@ -119,6 +123,21 @@ async def register_user(
)
try:
# 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码
try:
# 查询 IP 地理位置
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
@@ -137,7 +156,7 @@ async def register_user(
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country_code="CN", # 默认国家
country_code=country_code, # 根据 IP 地理位置设置国家
join_date=datetime.now(UTC),
last_visit=datetime.now(UTC),
is_supporter=settings.enable_supporter_for_all_users,