add Registration Interface
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -210,3 +210,4 @@ bancho.py-master/*
|
|||||||
|
|
||||||
# runtime file
|
# runtime file
|
||||||
replays/
|
replays/
|
||||||
|
osu-master/*
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
# OAuth 相关模型
|
# OAuth 相关模型
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -34,3 +35,22 @@ class OAuthErrorResponse(BaseModel):
|
|||||||
error_description: str
|
error_description: str
|
||||||
hint: str
|
hint: str
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationErrorResponse(BaseModel):
|
||||||
|
"""注册错误响应模型"""
|
||||||
|
form_error: dict
|
||||||
|
|
||||||
|
|
||||||
|
class UserRegistrationErrors(BaseModel):
|
||||||
|
"""用户注册错误模型"""
|
||||||
|
username: List[str] = []
|
||||||
|
user_email: List[str] = []
|
||||||
|
password: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationRequestErrors(BaseModel):
|
||||||
|
"""注册请求错误模型"""
|
||||||
|
message: str | None = None
|
||||||
|
redirect: str | None = None
|
||||||
|
user: UserRegistrationErrors | None = None
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from app.auth import (
|
from app.auth import (
|
||||||
authenticate_user,
|
authenticate_user,
|
||||||
@@ -8,13 +10,16 @@ from app.auth import (
|
|||||||
generate_refresh_token,
|
generate_refresh_token,
|
||||||
get_token_by_refresh_token,
|
get_token_by_refresh_token,
|
||||||
store_token,
|
store_token,
|
||||||
|
get_password_hash,
|
||||||
)
|
)
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies import get_db
|
from app.dependencies import get_db
|
||||||
from app.models.oauth import TokenResponse, OAuthErrorResponse
|
from app.models.oauth import TokenResponse, OAuthErrorResponse, RegistrationErrorResponse, UserRegistrationErrors, RegistrationRequestErrors
|
||||||
|
from app.database import User as DBUser
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form
|
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@@ -31,9 +36,234 @@ def create_oauth_error_response(error: str, description: str, hint: str, status_
|
|||||||
content=error_data.model_dump()
|
content=error_data.model_dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def validate_username(username: str) -> List[str]:
|
||||||
|
"""验证用户名"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
if not username:
|
||||||
|
errors.append("Username is required")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
if len(username) < 3:
|
||||||
|
errors.append("Username must be at least 3 characters long")
|
||||||
|
|
||||||
|
if len(username) > 15:
|
||||||
|
errors.append("Username must be at most 15 characters long")
|
||||||
|
|
||||||
|
# 检查用户名格式(只允许字母、数字、下划线、连字符)
|
||||||
|
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
||||||
|
errors.append("Username can only contain letters, numbers, underscores, and hyphens")
|
||||||
|
|
||||||
|
# 检查是否以数字开头
|
||||||
|
if username[0].isdigit():
|
||||||
|
errors.append("Username cannot start with a number")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def validate_email(email: str) -> List[str]:
|
||||||
|
"""验证邮箱"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
if not email:
|
||||||
|
errors.append("Email is required")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
# 基本的邮箱格式验证
|
||||||
|
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||||
|
if not re.match(email_pattern, email):
|
||||||
|
errors.append("Please enter a valid email address")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
def validate_password(password: str) -> List[str]:
|
||||||
|
"""验证密码"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
if not password:
|
||||||
|
errors.append("Password is required")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
if len(password) < 8:
|
||||||
|
errors.append("Password must be at least 8 characters long")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
async def register_user(
|
||||||
|
user_username: str = Form(..., alias="user[username]"),
|
||||||
|
user_email: str = Form(..., alias="user[user_email]"),
|
||||||
|
user_password: str = Form(..., alias="user[password]"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""用户注册接口 - 匹配 osu! 客户端的注册请求"""
|
||||||
|
|
||||||
|
username_errors = validate_username(user_username)
|
||||||
|
email_errors = validate_email(user_email)
|
||||||
|
password_errors = validate_password(user_password)
|
||||||
|
|
||||||
|
result = await db.exec(
|
||||||
|
select(DBUser).where(DBUser.name == user_username)
|
||||||
|
)
|
||||||
|
existing_user = result.first()
|
||||||
|
if existing_user:
|
||||||
|
username_errors.append("Username is already taken")
|
||||||
|
|
||||||
|
result = await db.exec(
|
||||||
|
select(DBUser).where(DBUser.email == user_email)
|
||||||
|
)
|
||||||
|
existing_email = result.first()
|
||||||
|
if existing_email:
|
||||||
|
email_errors.append("Email is already taken")
|
||||||
|
|
||||||
|
if username_errors or email_errors or password_errors:
|
||||||
|
errors = RegistrationRequestErrors(
|
||||||
|
user=UserRegistrationErrors(
|
||||||
|
username=username_errors,
|
||||||
|
user_email=email_errors,
|
||||||
|
password=password_errors
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422,
|
||||||
|
content={"form_error": errors.model_dump()}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建新用户
|
||||||
|
from datetime import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
|
new_user = DBUser(
|
||||||
|
name=user_username,
|
||||||
|
safe_name=user_username.lower(), # 安全用户名(小写)
|
||||||
|
email=user_email,
|
||||||
|
pw_bcrypt=get_password_hash(user_password),
|
||||||
|
priv=1, # 普通用户权限
|
||||||
|
country="CN", # 默认国家
|
||||||
|
creation_time=int(time.time()),
|
||||||
|
latest_activity=int(time.time()),
|
||||||
|
preferred_mode=0, # 默认模式
|
||||||
|
play_style=0, # 默认游戏风格
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(new_user)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(new_user)
|
||||||
|
|
||||||
|
# 保存用户ID,因为会话可能会关闭
|
||||||
|
user_id = new_user.id
|
||||||
|
|
||||||
|
if user_id <= 2:
|
||||||
|
await db.rollback()
|
||||||
|
try:
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||||
|
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 重新创建用户
|
||||||
|
new_user = DBUser(
|
||||||
|
name=user_username,
|
||||||
|
safe_name=user_username.lower(),
|
||||||
|
email=user_email,
|
||||||
|
pw_bcrypt=get_password_hash(user_password),
|
||||||
|
priv=1,
|
||||||
|
country="CN",
|
||||||
|
creation_time=int(time.time()),
|
||||||
|
latest_activity=int(time.time()),
|
||||||
|
preferred_mode=0,
|
||||||
|
play_style=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(new_user)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(new_user)
|
||||||
|
user_id = new_user.id
|
||||||
|
|
||||||
|
# 最终检查ID是否有效
|
||||||
|
if user_id <= 2:
|
||||||
|
await db.rollback()
|
||||||
|
errors = RegistrationRequestErrors(
|
||||||
|
message="Failed to create account with valid ID. Please contact support."
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"form_error": errors.model_dump()}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as fix_error:
|
||||||
|
await db.rollback()
|
||||||
|
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
|
||||||
|
errors = RegistrationRequestErrors(
|
||||||
|
message="Failed to create account with valid ID. Please try again."
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"form_error": errors.model_dump()}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建默认的 lazer_profile
|
||||||
|
from app.database.user import LazerUserProfile
|
||||||
|
lazer_profile = LazerUserProfile(
|
||||||
|
user_id=user_id,
|
||||||
|
is_active=True,
|
||||||
|
is_bot=False,
|
||||||
|
is_deleted=False,
|
||||||
|
is_online=True,
|
||||||
|
is_supporter=False,
|
||||||
|
is_restricted=False,
|
||||||
|
session_verified=False,
|
||||||
|
has_supported=False,
|
||||||
|
pm_friends_only=False,
|
||||||
|
default_group="default",
|
||||||
|
join_date=datetime.utcnow(),
|
||||||
|
playmode="osu",
|
||||||
|
support_level=0,
|
||||||
|
max_blocks=50,
|
||||||
|
max_friends=250,
|
||||||
|
post_count=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(lazer_profile)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 返回成功响应
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=201,
|
||||||
|
content={
|
||||||
|
"message": "Account created successfully",
|
||||||
|
"user_id": user_id
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
# 打印详细错误信息用于调试
|
||||||
|
print(f"Registration error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 返回通用错误
|
||||||
|
errors = RegistrationRequestErrors(
|
||||||
|
message="An error occurred while creating your account. Please try again."
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"form_error": errors.model_dump()}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/oauth/token", response_model=TokenResponse)
|
@router.post("/oauth/token", response_model=TokenResponse)
|
||||||
async def oauth_token(
|
async def oauth_token(
|
||||||
grant_type: str = Form(...),
|
grant_type: str = Form(...),
|
||||||
|
|||||||
Reference in New Issue
Block a user