refactor(user): refactor user database

**Breaking Change**

用户表变为 lazer_users

建议删除与用户关联的表进行迁移
This commit is contained in:
MingxuanGame
2025-07-30 16:17:09 +00:00
parent 3900babe3d
commit 9ce99398ab
37 changed files with 994 additions and 2073 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import timedelta
from datetime import UTC, datetime, timedelta
import re
from app.auth import (
@@ -12,17 +12,21 @@ from app.auth import (
store_token,
)
from app.config import settings
from app.database import User as DBUser
from app.database import DailyChallengeStats, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -110,12 +114,12 @@ async def register_user(
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
result = await db.exec(select(User).where(User.username == user_username))
existing_user = result.first()
if existing_user:
username_errors.append("Username is already taken")
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
result = await db.exec(select(User).where(User.email == user_email))
existing_email = result.first()
if existing_email:
email_errors.append("Email is already taken")
@@ -135,119 +139,41 @@ async def register_user(
try:
# 创建新用户
from datetime import datetime
import time
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
)
)
next_id = result.one()[0]
if next_id <= 2:
await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
await db.commit()
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(), # 安全用户名(小写)
new_user = User(
username=user_username,
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country="CN", # 默认国家
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0, # 默认模式
play_style=0, # 默认游戏风格
country_code="CN", # 默认国家
join_date=datetime.now(UTC),
last_visit=datetime.now(UTC),
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
# 保存用户ID因为会话可能会关闭
user_id = new_user.id
if user_id <= 2:
await db.rollback()
try:
from sqlalchemy import text
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
await db.commit()
# 重新创建用户
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(),
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1,
country="CN",
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0,
play_style=0,
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
user_id = new_user.id
# 最终检查ID是否有效
if user_id <= 2:
await db.rollback()
errors = RegistrationRequestErrors(
message=(
"Failed to create account with valid ID. "
"Please contact support."
)
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
except Exception as fix_error:
await db.rollback()
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please try again."
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
# 创建默认的 lazer_profile
from app.database.user import LazerUserProfile
lazer_profile = LazerUserProfile(
user_id=user_id,
is_active=True,
is_bot=False,
is_deleted=False,
is_online=True,
is_supporter=False,
is_restricted=False,
session_verified=False,
has_supported=False,
pm_friends_only=False,
default_group="default",
join_date=datetime.utcnow(),
playmode="osu",
support_level=0,
max_blocks=50,
max_friends=250,
post_count=0,
)
db.add(lazer_profile)
assert new_user.id is not None, "New user ID should not be None"
for i in GameMode:
statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
db.add(daily_challenge_user_stats)
await db.commit()
# 返回成功响应
return JSONResponse(
status_code=201,
content={"message": "Account created successfully", "user_id": user_id},
)
except Exception as e:
except Exception:
await db.rollback()
# 打印详细错误信息用于调试
print(f"Registration error: {e}")
import traceback
traceback.print_exc()
logger.exception(f"Registration error for user {user_username}")
# 返回通用错误
errors = RegistrationRequestErrors(
@@ -323,6 +249,7 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
await store_token(
db,
user.id,

View File

@@ -5,12 +5,7 @@ import hashlib
import json
from app.calculator import calculate_beatmap_attribute
from app.database import (
Beatmap,
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.database import Beatmap, BeatmapResp, Beatmapset, User
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -39,7 +34,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -62,7 +57,7 @@ async def lookup_beatmap(
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -81,7 +76,7 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
@@ -126,7 +121,7 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
beatmap: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None),

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from app.database import (
Beatmapset,
BeatmapsetResp,
User as DBUser,
User,
)
from app.dependencies.database import get_db
from app.dependencies.fetcher import get_fetcher
@@ -22,7 +22,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):

View File

@@ -1,28 +1,34 @@
from __future__ import annotations
from typing import Literal
from app.database import (
User as DBUser,
)
from app.database import User, UserResp
from app.dependencies import get_current_user
from app.models.user import (
User as ApiUser,
)
from app.utils import convert_db_user_to_api_user
from app.dependencies.database import get_db
from app.models.score import GameMode
from .api_router import router
from fastapi import Depends
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/me/{ruleset}", response_model=ApiUser)
@router.get("/me/", response_model=ApiUser)
@router.get("/me/{ruleset}", response_model=UserResp)
@router.get("/me/", response_model=UserResp)
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
ruleset: GameMode | None = None,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
"""获取当前用户信息默认使用osu模式"""
# 默认使用osu模式
api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user
return await UserResp.from_db(
current_user,
session,
[
"friends",
"team",
"account_history",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
"achievements",
],
ruleset,
)

View File

@@ -8,6 +8,7 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
from pydantic import BaseModel
from sqlalchemy.orm import joinedload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -36,7 +37,11 @@ async def get_relationship(
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
class AddFriendResp(BaseModel):
user_relation: RelationshipResp
@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
@router.post("/blocks", tags=["relationship"])
async def add_relationship(
request: Request,
@@ -98,7 +103,9 @@ async def add_relationship(
)
).first()
assert relationship, "Relationship should exist after commit"
return await RelationshipResp.from_db(db, relationship)
return AddFriendResp(
user_relation=await RelationshipResp.from_db(db, relationship)
)
@router.delete("/friends/{target}", tags=["relationship"])

View File

@@ -1,11 +1,7 @@
from __future__ import annotations
from app.database import (
User as DBUser,
)
from app.database.beatmap import Beatmap
from app.database.score import Score, ScoreResp, process_score, process_user
from app.database.score_token import ScoreToken, ScoreTokenResp
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
from app.database.score import process_score, process_user
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -41,7 +37,7 @@ async def get_beatmap_scores(
mode: GameMode | None = Query(None),
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
type: str = Query(None),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -94,7 +90,7 @@ async def get_user_beatmap_score(
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -134,7 +130,7 @@ async def get_user_all_beatmap_scores(
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -166,9 +162,10 @@ async def create_solo_score(
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id
async with db:
score_token = ScoreToken(
user_id=current_user.id,
@@ -190,7 +187,7 @@ async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),

View File

@@ -1,12 +1,8 @@
from __future__ import annotations
from typing import Literal
from app.database import User as DBUser
from app.database import User, UserResp
from app.dependencies.database import get_db
from app.models.score import INT_TO_MODE
from app.models.user import User as ApiUser
from app.utils import convert_db_user_to_api_user
from app.models.score import GameMode
from .api_router import router
@@ -17,28 +13,17 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
# ---------- Shared Utility ----------
async def get_user_by_lookup(
db: AsyncSession, lookup: str, key: str = "id"
) -> DBUser | None:
"""根据查找方式获取用户"""
if key == "id":
try:
user_id = int(lookup)
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
return result.first()
except ValueError:
return None
elif key == "username":
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
return result.first()
else:
return None
# ---------- Batch Users ----------
class BatchUserResponse(BaseModel):
users: list[ApiUser]
users: list[UserResp]
SEARCH_INCLUDE = [
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
"achievements",
]
@router.get("/users", response_model=BatchUserResponse)
@@ -52,74 +37,54 @@ async def get_users(
if user_ids:
searched_users = (
await session.exec(
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
select(User)
.options(*User.all_select_option())
.limit(50)
.where(col(User.id).in_(user_ids))
)
).all()
else:
searched_users = (
await session.exec(DBUser.all_select_clause().limit(50))
await session.exec(
select(User).options(*User.all_select_option()).limit(50)
)
).all()
return BatchUserResponse(
users=[
await convert_db_user_to_api_user(
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDE,
)
for searched_user in searched_users
]
)
# # ---------- Individual User ----------
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
# async def get_user_with_mode(
# user_lookup: str,
# mode: Literal["osu", "taiko", "fruits", "mania"],
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取指定游戏模式的用户信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, mode)
# @router.get("/users/{user_lookup}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
# async def get_user_default(
# user_lookup: str,
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取用户信息默认使用osu模式但包含所有模式的统计信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, "osu")
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
@router.get("/users/{user}/", response_model=ApiUser)
@router.get("/users/{user}", response_model=ApiUser)
@router.get("/users/{user}/{ruleset}", response_model=UserResp)
@router.get("/users/{user}/", response_model=UserResp)
@router.get("/users/{user}", response_model=UserResp)
async def get_user_info(
user: str,
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db),
):
searched_user = (
await session.exec(
DBUser.all_select_clause().where(
DBUser.id == int(user)
select(User)
.options(*User.all_select_option())
.where(
User.id == int(user)
if user.isdigit()
else DBUser.name == user.removeprefix("@")
else User.username == user.removeprefix("@")
)
)
).first()
if not searched_user:
raise HTTPException(404, detail="User not found")
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
return await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDE,
ruleset=ruleset,
)