add Registration Interface

This commit is contained in:
咕谷酱
2025-07-28 19:41:57 +08:00
parent 2cd84f6175
commit 80310d450b
3 changed files with 253 additions and 2 deletions

1
.gitignore vendored
View File

@@ -210,3 +210,4 @@ bancho.py-master/*
# runtime file
replays/
osu-master/*

View File

@@ -1,6 +1,7 @@
# OAuth 相关模型
from __future__ import annotations
from typing import List
from pydantic import BaseModel
@@ -34,3 +35,22 @@ class OAuthErrorResponse(BaseModel):
error_description: str
hint: str
message: str
class RegistrationErrorResponse(BaseModel):
"""注册错误响应模型"""
form_error: dict
class UserRegistrationErrors(BaseModel):
"""用户注册错误模型"""
username: List[str] = []
user_email: List[str] = []
password: List[str] = []
class RegistrationRequestErrors(BaseModel):
"""注册请求错误模型"""
message: str | None = None
redirect: str | None = None
user: UserRegistrationErrors | None = None

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from datetime import timedelta
import re
from typing import List
from app.auth import (
authenticate_user,
@@ -8,13 +10,16 @@ from app.auth import (
generate_refresh_token,
get_token_by_refresh_token,
store_token,
get_password_hash,
)
from app.config import settings
from app.dependencies import get_db
from app.models.oauth import TokenResponse, OAuthErrorResponse
from app.models.oauth import TokenResponse, OAuthErrorResponse, RegistrationErrorResponse, UserRegistrationErrors, RegistrationRequestErrors
from app.database import User as DBUser
from fastapi import APIRouter, Depends, Form
from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi.responses import JSONResponse
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -31,9 +36,234 @@ def create_oauth_error_response(error: str, description: str, hint: str, status_
content=error_data.model_dump()
)
def validate_username(username: str) -> List[str]:
"""验证用户名"""
errors = []
if not username:
errors.append("Username is required")
return errors
if len(username) < 3:
errors.append("Username must be at least 3 characters long")
if len(username) > 15:
errors.append("Username must be at most 15 characters long")
# 检查用户名格式(只允许字母、数字、下划线、连字符)
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
errors.append("Username can only contain letters, numbers, underscores, and hyphens")
# 检查是否以数字开头
if username[0].isdigit():
errors.append("Username cannot start with a number")
return errors
def validate_email(email: str) -> List[str]:
"""验证邮箱"""
errors = []
if not email:
errors.append("Email is required")
return errors
# 基本的邮箱格式验证
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(email_pattern, email):
errors.append("Please enter a valid email address")
return errors
def validate_password(password: str) -> List[str]:
"""验证密码"""
errors = []
if not password:
errors.append("Password is required")
return errors
if len(password) < 8:
errors.append("Password must be at least 8 characters long")
return errors
router = APIRouter(tags=["osu! OAuth 认证"])
@router.post("/users")
async def register_user(
user_username: str = Form(..., alias="user[username]"),
user_email: str = Form(..., alias="user[user_email]"),
user_password: str = Form(..., alias="user[password]"),
db: AsyncSession = Depends(get_db),
):
"""用户注册接口 - 匹配 osu! 客户端的注册请求"""
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
result = await db.exec(
select(DBUser).where(DBUser.name == user_username)
)
existing_user = result.first()
if existing_user:
username_errors.append("Username is already taken")
result = await db.exec(
select(DBUser).where(DBUser.email == user_email)
)
existing_email = result.first()
if existing_email:
email_errors.append("Email is already taken")
if username_errors or email_errors or password_errors:
errors = RegistrationRequestErrors(
user=UserRegistrationErrors(
username=username_errors,
user_email=email_errors,
password=password_errors
)
)
return JSONResponse(
status_code=422,
content={"form_error": errors.model_dump()}
)
try:
# 创建新用户
from datetime import datetime
import time
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(), # 安全用户名(小写)
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country="CN", # 默认国家
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0, # 默认模式
play_style=0, # 默认游戏风格
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
# 保存用户ID因为会话可能会关闭
user_id = new_user.id
if user_id <= 2:
await db.rollback()
try:
from sqlalchemy import text
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
await db.commit()
# 重新创建用户
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(),
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1,
country="CN",
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0,
play_style=0,
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
user_id = new_user.id
# 最终检查ID是否有效
if user_id <= 2:
await db.rollback()
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please contact support."
)
return JSONResponse(
status_code=500,
content={"form_error": errors.model_dump()}
)
except Exception as fix_error:
await db.rollback()
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please try again."
)
return JSONResponse(
status_code=500,
content={"form_error": errors.model_dump()}
)
# 创建默认的 lazer_profile
from app.database.user import LazerUserProfile
lazer_profile = LazerUserProfile(
user_id=user_id,
is_active=True,
is_bot=False,
is_deleted=False,
is_online=True,
is_supporter=False,
is_restricted=False,
session_verified=False,
has_supported=False,
pm_friends_only=False,
default_group="default",
join_date=datetime.utcnow(),
playmode="osu",
support_level=0,
max_blocks=50,
max_friends=250,
post_count=0,
)
db.add(lazer_profile)
await db.commit()
# 返回成功响应
return JSONResponse(
status_code=201,
content={
"message": "Account created successfully",
"user_id": user_id
}
)
except Exception as e:
await db.rollback()
# 打印详细错误信息用于调试
print(f"Registration error: {e}")
import traceback
traceback.print_exc()
# 返回通用错误
errors = RegistrationRequestErrors(
message="An error occurred while creating your account. Please try again."
)
return JSONResponse(
status_code=500,
content={"form_error": errors.model_dump()}
)
@router.post("/oauth/token", response_model=TokenResponse)
async def oauth_token(
grant_type: str = Form(...),