feat(user): support search user

This commit is contained in:
MingxuanGame
2025-07-28 14:18:43 +00:00
parent 9b889bc602
commit e1ce364ac9
5 changed files with 174 additions and 75 deletions

View File

@@ -7,7 +7,8 @@ from .team import TeamMember
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
from sqlalchemy.dialects.mysql import VARCHAR
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
class User(SQLModel, table=True):
@@ -100,6 +101,26 @@ class User(SQLModel, table=True):
back_populates="user"
)
@classmethod
def all_select_clause(cls):
return select(cls).options(
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
# ============================================
# Lazer API 专用表模型

View File

@@ -9,8 +9,6 @@ from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
@@ -35,25 +33,7 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | No
return None
user = (
await db.exec(
select(DBUser)
.options(
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
joinedload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
.where(DBUser.id == token_record.user_id)
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
)
).first()
return user

View File

@@ -8,6 +8,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
me,
relationship,
score,
user,
)
from .api_router import router as api_router
from .auth import router as auth_router

View File

@@ -2,43 +2,42 @@ from __future__ import annotations
from datetime import timedelta
import re
from typing import List
from app.auth import (
authenticate_user,
create_access_token,
generate_refresh_token,
get_password_hash,
get_token_by_refresh_token,
store_token,
get_password_hash,
)
from app.config import settings
from app.dependencies import get_db
from app.models.oauth import TokenResponse, OAuthErrorResponse, RegistrationErrorResponse, UserRegistrationErrors, RegistrationRequestErrors
from app.database import User as DBUser
from app.dependencies import get_db
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
def create_oauth_error_response(
error: str, description: str, hint: str, status_code: int = 400
):
"""创建标准的 OAuth 错误响应"""
error_data = OAuthErrorResponse(
error=error,
error_description=description,
hint=hint,
message=description
)
return JSONResponse(
status_code=status_code,
content=error_data.model_dump()
error=error, error_description=description, hint=hint, message=description
)
return JSONResponse(status_code=status_code, content=error_data.model_dump())
def validate_username(username: str) -> List[str]:
def validate_username(username: str) -> list[str]:
"""验证用户名"""
errors = []
@@ -53,8 +52,10 @@ def validate_username(username: str) -> List[str]:
errors.append("Username must be at most 15 characters long")
# 检查用户名格式(只允许字母、数字、下划线、连字符)
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
errors.append("Username can only contain letters, numbers, underscores, and hyphens")
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
errors.append(
"Username can only contain letters, numbers, underscores, and hyphens"
)
# 检查是否以数字开头
if username[0].isdigit():
@@ -63,7 +64,7 @@ def validate_username(username: str) -> List[str]:
return errors
def validate_email(email: str) -> List[str]:
def validate_email(email: str) -> list[str]:
"""验证邮箱"""
errors = []
@@ -72,14 +73,14 @@ def validate_email(email: str) -> List[str]:
return errors
# 基本的邮箱格式验证
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_pattern, email):
errors.append("Please enter a valid email address")
return errors
def validate_password(password: str) -> List[str]:
def validate_password(password: str) -> list[str]:
"""验证密码"""
errors = []
@@ -109,16 +110,12 @@ async def register_user(
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
result = await db.exec(
select(DBUser).where(DBUser.name == user_username)
)
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
existing_user = result.first()
if existing_user:
username_errors.append("Username is already taken")
result = await db.exec(
select(DBUser).where(DBUser.email == user_email)
)
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
existing_email = result.first()
if existing_email:
email_errors.append("Email is already taken")
@@ -128,13 +125,12 @@ async def register_user(
user=UserRegistrationErrors(
username=username_errors,
user_email=email_errors,
password=password_errors
password=password_errors,
)
)
return JSONResponse(
status_code=422,
content={"form_error": errors.model_dump()}
status_code=422, content={"form_error": errors.model_dump()}
)
try:
@@ -194,11 +190,13 @@ async def register_user(
if user_id <= 2:
await db.rollback()
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please contact support."
message=(
"Failed to create account with valid ID. "
"Please contact support."
)
)
return JSONResponse(
status_code=500,
content={"form_error": errors.model_dump()}
status_code=500, content={"form_error": errors.model_dump()}
)
except Exception as fix_error:
@@ -208,12 +206,12 @@ async def register_user(
message="Failed to create account with valid ID. Please try again."
)
return JSONResponse(
status_code=500,
content={"form_error": errors.model_dump()}
status_code=500, content={"form_error": errors.model_dump()}
)
# 创建默认的 lazer_profile
from app.database.user import LazerUserProfile
lazer_profile = LazerUserProfile(
user_id=user_id,
is_active=True,
@@ -240,10 +238,7 @@ async def register_user(
# 返回成功响应
return JSONResponse(
status_code=201,
content={
"message": "Account created successfully",
"user_id": user_id
}
content={"message": "Account created successfully", "user_id": user_id},
)
except Exception as e:
@@ -251,6 +246,7 @@ async def register_user(
# 打印详细错误信息用于调试
print(f"Registration error: {e}")
import traceback
traceback.print_exc()
# 返回通用错误
@@ -259,8 +255,7 @@ async def register_user(
)
return JSONResponse(
status_code=500,
content={"form_error": errors.model_dump()}
status_code=500, content={"form_error": errors.model_dump()}
)
@@ -283,9 +278,13 @@ async def oauth_token(
):
return create_oauth_error_response(
error="invalid_client",
description="Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method).",
description=(
"Client authentication failed (e.g., unknown client, "
"no client authentication included, "
"or unsupported authentication method)."
),
hint="Invalid client credentials",
status_code=401
status_code=401,
)
if grant_type == "password":
@@ -293,8 +292,12 @@ async def oauth_token(
if not username or not password:
return create_oauth_error_response(
error="invalid_request",
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
hint="Username and password required"
description=(
"The request is missing a required parameter, includes an "
"invalid parameter value, "
"includes a parameter more than once, or is otherwise malformed."
),
hint="Username and password required",
)
# 验证用户
@@ -302,8 +305,14 @@ async def oauth_token(
if not user:
return create_oauth_error_response(
error="invalid_grant",
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
hint="Incorrect sign in"
description=(
"The provided authorization grant (e.g., authorization code, "
"resource owner credentials) "
"or refresh token is invalid, expired, revoked, "
"does not match the redirection URI used in "
"the authorization request, or was issued to another client."
),
hint="Incorrect sign in",
)
# 生成令牌
@@ -335,8 +344,12 @@ async def oauth_token(
if not refresh_token:
return create_oauth_error_response(
error="invalid_request",
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
hint="Refresh token required"
description=(
"The request is missing a required parameter, "
"includes an invalid parameter value, "
"includes a parameter more than once, or is otherwise malformed."
),
hint="Refresh token required",
)
# 验证刷新令牌
@@ -344,8 +357,14 @@ async def oauth_token(
if not token_record:
return create_oauth_error_response(
error="invalid_grant",
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
hint="Invalid refresh token"
description=(
"The provided authorization grant (e.g., authorization code, "
"resource owner credentials) or refresh token is "
"invalid, expired, revoked, "
"does not match the redirection URI used "
"in the authorization request, or was issued to another client."
),
hint="Invalid refresh token",
)
# 生成新的访问令牌
@@ -375,6 +394,9 @@ async def oauth_token(
else:
return create_oauth_error_response(
error="unsupported_grant_type",
description="The authorization grant type is not supported by the authorization server.",
hint="Unsupported grant type"
description=(
"The authorization grant type is not supported "
"by the authorization server."
),
hint="Unsupported grant type",
)

75
app/router/user.py Normal file
View File

@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Literal
from app.database import (
User as DBUser,
)
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.models.score import INT_TO_MODE
from app.models.user import (
User as ApiUser,
)
from app.utils import convert_db_user_to_api_user
from .api_router import router
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
@router.get("/users/{user}", response_model=ApiUser)
async def get_user_info_default(
user: str,
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
searched_user = (
await session.exec(
DBUser.all_select_clause().where(
DBUser.id == int(user)
if user.isdigit()
else DBUser.name == user.removeprefix("@")
)
)
).first()
if not searched_user:
raise HTTPException(404, detail="User not found")
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
class BatchUserResponse(BaseModel):
users: list[ApiUser]
@router.get("/users", response_model=BatchUserResponse)
@router.get("/users/lookup", response_model=BatchUserResponse)
async def get_users(
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
include_variant_statistics: bool = Query(default=False), # TODO
current_user: DBUser = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
if user_ids:
searched_users = (
await session.exec(
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
)
).all()
else:
searched_users = (
await session.exec(DBUser.all_select_clause().limit(50))
).all()
return BatchUserResponse(
users=[
await convert_db_user_to_api_user(
searched_user, ruleset=INT_TO_MODE[current_user.preferred_mode].value
)
for searched_user in searched_users
]
)