From 80310d450b8266fe17299427bb91f2f6d845ea68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= Date: Mon, 28 Jul 2025 19:41:57 +0800 Subject: [PATCH] add Registration Interface --- .gitignore | 1 + app/models/oauth.py | 20 ++++ app/router/auth.py | 234 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 253 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 6923d04..369e759 100644 --- a/.gitignore +++ b/.gitignore @@ -210,3 +210,4 @@ bancho.py-master/* # runtime file replays/ +osu-master/* \ No newline at end of file diff --git a/app/models/oauth.py b/app/models/oauth.py index 3776bdd..22fcf63 100644 --- a/app/models/oauth.py +++ b/app/models/oauth.py @@ -1,6 +1,7 @@ # OAuth 相关模型 from __future__ import annotations +from typing import List from pydantic import BaseModel @@ -34,3 +35,22 @@ class OAuthErrorResponse(BaseModel): error_description: str hint: str message: str + + +class RegistrationErrorResponse(BaseModel): + """注册错误响应模型""" + form_error: dict + + +class UserRegistrationErrors(BaseModel): + """用户注册错误模型""" + username: List[str] = [] + user_email: List[str] = [] + password: List[str] = [] + + +class RegistrationRequestErrors(BaseModel): + """注册请求错误模型""" + message: str | None = None + redirect: str | None = None + user: UserRegistrationErrors | None = None diff --git a/app/router/auth.py b/app/router/auth.py index 9c7df98..e534f18 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -1,6 +1,8 @@ from __future__ import annotations from datetime import timedelta +import re +from typing import List from app.auth import ( authenticate_user, @@ -8,13 +10,16 @@ from app.auth import ( generate_refresh_token, get_token_by_refresh_token, store_token, + get_password_hash, ) from app.config import settings from app.dependencies import get_db -from app.models.oauth import TokenResponse, OAuthErrorResponse +from app.models.oauth import TokenResponse, OAuthErrorResponse, RegistrationErrorResponse, UserRegistrationErrors, RegistrationRequestErrors +from app.database import User as DBUser -from fastapi import APIRouter, Depends, Form +from fastapi import APIRouter, Depends, Form, HTTPException from fastapi.responses import JSONResponse +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -31,9 +36,234 @@ def create_oauth_error_response(error: str, description: str, hint: str, status_ content=error_data.model_dump() ) + + +def validate_username(username: str) -> List[str]: + """验证用户名""" + errors = [] + + if not username: + errors.append("Username is required") + return errors + + if len(username) < 3: + errors.append("Username must be at least 3 characters long") + + if len(username) > 15: + errors.append("Username must be at most 15 characters long") + + # 检查用户名格式(只允许字母、数字、下划线、连字符) + if not re.match(r'^[a-zA-Z0-9_-]+$', username): + errors.append("Username can only contain letters, numbers, underscores, and hyphens") + + # 检查是否以数字开头 + if username[0].isdigit(): + errors.append("Username cannot start with a number") + + return errors + + +def validate_email(email: str) -> List[str]: + """验证邮箱""" + errors = [] + + if not email: + errors.append("Email is required") + return errors + + # 基本的邮箱格式验证 + email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + if not re.match(email_pattern, email): + errors.append("Please enter a valid email address") + + return errors + + +def validate_password(password: str) -> List[str]: + """验证密码""" + errors = [] + + if not password: + errors.append("Password is required") + return errors + + if len(password) < 8: + errors.append("Password must be at least 8 characters long") + + return errors + + router = APIRouter(tags=["osu! OAuth 认证"]) +@router.post("/users") +async def register_user( + user_username: str = Form(..., alias="user[username]"), + user_email: str = Form(..., alias="user[user_email]"), + user_password: str = Form(..., alias="user[password]"), + db: AsyncSession = Depends(get_db), +): + """用户注册接口 - 匹配 osu! 客户端的注册请求""" + + username_errors = validate_username(user_username) + email_errors = validate_email(user_email) + password_errors = validate_password(user_password) + + result = await db.exec( + select(DBUser).where(DBUser.name == user_username) + ) + existing_user = result.first() + if existing_user: + username_errors.append("Username is already taken") + + result = await db.exec( + select(DBUser).where(DBUser.email == user_email) + ) + existing_email = result.first() + if existing_email: + email_errors.append("Email is already taken") + + if username_errors or email_errors or password_errors: + errors = RegistrationRequestErrors( + user=UserRegistrationErrors( + username=username_errors, + user_email=email_errors, + password=password_errors + ) + ) + + return JSONResponse( + status_code=422, + content={"form_error": errors.model_dump()} + ) + + try: + # 创建新用户 + from datetime import datetime + import time + + new_user = DBUser( + name=user_username, + safe_name=user_username.lower(), # 安全用户名(小写) + email=user_email, + pw_bcrypt=get_password_hash(user_password), + priv=1, # 普通用户权限 + country="CN", # 默认国家 + creation_time=int(time.time()), + latest_activity=int(time.time()), + preferred_mode=0, # 默认模式 + play_style=0, # 默认游戏风格 + ) + + db.add(new_user) + await db.commit() + await db.refresh(new_user) + + # 保存用户ID,因为会话可能会关闭 + user_id = new_user.id + + if user_id <= 2: + await db.rollback() + try: + from sqlalchemy import text + + # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) + await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3")) + await db.commit() + + # 重新创建用户 + new_user = DBUser( + name=user_username, + safe_name=user_username.lower(), + email=user_email, + pw_bcrypt=get_password_hash(user_password), + priv=1, + country="CN", + creation_time=int(time.time()), + latest_activity=int(time.time()), + preferred_mode=0, + play_style=0, + ) + + db.add(new_user) + await db.commit() + await db.refresh(new_user) + user_id = new_user.id + + # 最终检查ID是否有效 + if user_id <= 2: + await db.rollback() + errors = RegistrationRequestErrors( + message="Failed to create account with valid ID. Please contact support." + ) + return JSONResponse( + status_code=500, + content={"form_error": errors.model_dump()} + ) + + except Exception as fix_error: + await db.rollback() + print(f"Failed to fix AUTO_INCREMENT: {fix_error}") + errors = RegistrationRequestErrors( + message="Failed to create account with valid ID. Please try again." + ) + return JSONResponse( + status_code=500, + content={"form_error": errors.model_dump()} + ) + + # 创建默认的 lazer_profile + from app.database.user import LazerUserProfile + lazer_profile = LazerUserProfile( + user_id=user_id, + is_active=True, + is_bot=False, + is_deleted=False, + is_online=True, + is_supporter=False, + is_restricted=False, + session_verified=False, + has_supported=False, + pm_friends_only=False, + default_group="default", + join_date=datetime.utcnow(), + playmode="osu", + support_level=0, + max_blocks=50, + max_friends=250, + post_count=0, + ) + + db.add(lazer_profile) + await db.commit() + + # 返回成功响应 + return JSONResponse( + status_code=201, + content={ + "message": "Account created successfully", + "user_id": user_id + } + ) + + except Exception as e: + await db.rollback() + # 打印详细错误信息用于调试 + print(f"Registration error: {e}") + import traceback + traceback.print_exc() + + # 返回通用错误 + errors = RegistrationRequestErrors( + message="An error occurred while creating your account. Please try again." + ) + + return JSONResponse( + status_code=500, + content={"form_error": errors.model_dump()} + ) + + @router.post("/oauth/token", response_model=TokenResponse) async def oauth_token( grant_type: str = Form(...),