From c43ca883a5a68f668177d43bdd24b16ae3571bc4 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 24 Jul 2025 20:49:07 +0800 Subject: [PATCH] refactor(database): migrate to sqlmodel --- app/auth.py | 49 +- app/database.py | 873 +++++++++++++++++------------------ app/dependencies/database.py | 11 +- app/dependencies/user.py | 14 +- app/models/user.py | 13 +- app/router/auth.py | 4 +- app/utils.py | 282 +++-------- create_sample_data.py | 53 ++- pyproject.toml | 1 + test_lazer.py | 10 +- uv.lock | 15 + 11 files changed, 582 insertions(+), 743 deletions(-) diff --git a/app/auth.py b/app/auth.py index 7bebdd5..da141a0 100644 --- a/app/auth.py +++ b/app/auth.py @@ -4,7 +4,6 @@ from datetime import datetime, timedelta import hashlib import secrets import string -from typing import Optional from app.config import settings from app.database import ( @@ -15,7 +14,7 @@ from app.database import ( import bcrypt from jose import JWTError, jwt from passlib.context import CryptContext -from sqlalchemy.orm import Session +from sqlmodel import Session, select # 密码哈希上下文 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -71,7 +70,7 @@ def get_password_hash(password: str) -> str: return pw_bcrypt.decode() -def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[DBUser]: +def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None: """ 验证用户身份 - 使用类似 from_login 的逻辑 """ @@ -79,12 +78,13 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[ pw_md5 = hashlib.md5(password.encode()).hexdigest() # 2. 根据用户名查找用户 - user = db.query(DBUser).filter(DBUser.name == name).first() + statement = select(DBUser).where(DBUser.name == name) + user = db.exec(statement).first() if not user: return None - # 3. 验证密码 - if not (user.pw_bcrypt is None and user.pw_bcrypt != ""): + # 3. 验证密码 - 修复逻辑错误 + if user.pw_bcrypt is None or user.pw_bcrypt == "": return None # 4. 检查缓存 @@ -107,12 +107,12 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[ return None -def authenticate_user(db: Session, username: str, password: str) -> Optional[DBUser]: +def authenticate_user(db: Session, username: str, password: str) -> DBUser | None: """验证用户身份""" return authenticate_user_legacy(db, username, password) -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: +def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: """创建访问令牌""" to_encode = data.copy() if expires_delta: @@ -136,7 +136,7 @@ def generate_refresh_token() -> str: return "".join(secrets.choice(characters) for _ in range(length)) -def verify_token(token: str) -> Optional[dict]: +def verify_token(token: str) -> dict | None: """验证访问令牌""" try: payload = jwt.decode( @@ -154,7 +154,10 @@ def store_token( expires_at = datetime.utcnow() + timedelta(seconds=expires_in) # 删除用户的旧令牌 - db.query(OAuthToken).filter(OAuthToken.user_id == user_id).delete() + statement = select(OAuthToken).where(OAuthToken.user_id == user_id) + old_tokens = db.exec(statement).all() + for token in old_tokens: + db.delete(token) # 创建新令牌记录 token_record = OAuthToken( @@ -169,25 +172,19 @@ def store_token( return token_record -def get_token_by_access_token(db: Session, access_token: str) -> Optional[OAuthToken]: +def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None: """根据访问令牌获取令牌记录""" - return ( - db.query(OAuthToken) - .filter( - OAuthToken.access_token == access_token, - OAuthToken.expires_at > datetime.utcnow(), - ) - .first() + statement = select(OAuthToken).where( + OAuthToken.access_token == access_token, + OAuthToken.expires_at > datetime.utcnow(), ) + return db.exec(statement).first() -def get_token_by_refresh_token(db: Session, refresh_token: str) -> Optional[OAuthToken]: +def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None: """根据刷新令牌获取令牌记录""" - return ( - db.query(OAuthToken) - .filter( - OAuthToken.refresh_token == refresh_token, - OAuthToken.expires_at > datetime.utcnow(), - ) - .first() + statement = select(OAuthToken).where( + OAuthToken.refresh_token == refresh_token, + OAuthToken.expires_at > datetime.utcnow(), ) + return db.exec(statement).first() diff --git a/app/database.py b/app/database.py index 5bf536a..f45772e 100644 --- a/app/database.py +++ b/app/database.py @@ -1,61 +1,47 @@ -from __future__ import annotations - +# ruff: noqa: I002 from dataclasses import dataclass from datetime import datetime +from typing import TYPE_CHECKING, Optional -from sqlalchemy import ( - DECIMAL, - JSON, - Boolean, - Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - String, - Text, -) -from sqlalchemy.dialects.mysql import TINYINT, VARCHAR -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship +from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text +from sqlalchemy.dialects.mysql import VARCHAR +from sqlmodel import Field, Relationship, SQLModel -Base = declarative_base() +if TYPE_CHECKING: + pass -class User(Base): - __tablename__ = "users" +class User(SQLModel, table=True): + __tablename__ = "users" # pyright: ignore[reportAssignmentType] # 主键 - id = Column(Integer, primary_key=True, index=True) + id: int | None = Field(default=None, primary_key=True, index=True) # 基本信息(匹配 migrations 中的结构) - name = Column(String(32), unique=True, index=True, nullable=False) # 用户名 - safe_name = Column( - String(32), unique=True, index=True, nullable=False - ) # 安全用户名 - email = Column(String(254), unique=True, index=True, nullable=False) - priv = Column(Integer, default=1, nullable=False) # 权限 - pw_bcrypt = Column(String(60), nullable=False) # bcrypt 哈希密码 - country = Column(String(2), default="CN", nullable=False) # 国家代码 + name: str = Field(max_length=32, unique=True, index=True) # 用户名 + safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名 + email: str = Field(max_length=254, unique=True, index=True) + priv: int = Field(default=1) # 权限 + pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码 + country: str = Field(default="CN", max_length=2) # 国家代码 # 状态和时间 - silence_end = Column(Integer, default=0, nullable=False) - donor_end = Column(Integer, default=0, nullable=False) - creation_time = Column(Integer, default=0, nullable=False) # Unix 时间戳 - latest_activity = Column(Integer, default=0, nullable=False) # Unix 时间戳 + silence_end: int = Field(default=0) + donor_end: int = Field(default=0) + creation_time: int = Field(default=0) # Unix 时间戳 + latest_activity: int = Field(default=0) # Unix 时间戳 # 游戏相关 - preferred_mode = Column(Integer, default=0, nullable=False) # 偏好游戏模式 - play_style = Column(Integer, default=0, nullable=False) # 游戏风格 + preferred_mode: int = Field(default=0) # 偏好游戏模式 + play_style: int = Field(default=0) # 游戏风格 # 扩展信息 - clan_id = Column(Integer, default=0, nullable=False) - clan_priv = Column(Integer, default=0, nullable=False) - custom_badge_name = Column(String(16)) - custom_badge_icon = Column(String(64)) - userpage_content = Column(String(2048)) - api_key = Column(String(36), unique=True) + clan_id: int = Field(default=0) + clan_priv: int = Field(default=0) + custom_badge_name: str | None = Field(default=None, max_length=16) + custom_badge_icon: str | None = Field(default=None, max_length=64) + userpage_content: str | None = Field(default=None, max_length=2048) + api_key: str | None = Field(default=None, max_length=36, unique=True) # 虚拟字段用于兼容性 @property @@ -81,73 +67,33 @@ class User(Base): return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None # 关联关系 - lazer_profile = relationship( - "LazerUserProfile", - back_populates="user", - uselist=False, - cascade="all, delete-orphan", + lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user") + lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user") + lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user") + lazer_achievements: list["LazerUserAchievement"] = Relationship( + back_populates="user" ) - lazer_statistics = relationship( - "LazerUserStatistics", back_populates="user", cascade="all, delete-orphan" + lazer_profile_sections: list["LazerUserProfileSections"] = Relationship( + back_populates="user" ) - lazer_achievements = relationship( - "LazerUserAchievement", back_populates="user", cascade="all, delete-orphan" + statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user") + team_membership: list["TeamMember"] = Relationship(back_populates="user") + daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship( + back_populates="user" ) - lazer_profile_sections = relationship( - "LazerUserProfileSections", # 修正类名拼写(添加s) - back_populates="user", - cascade="all, delete-orphan", + rank_history: list["RankHistory"] = Relationship(back_populates="user") + avatar: Optional["UserAvatar"] = Relationship(back_populates="user") + active_banners: list["LazerUserBanners"] = Relationship(back_populates="user") + lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user") + lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship( + back_populates="user" ) - statistics = relationship( - "LegacyUserStatistics", back_populates="user", cascade="all, delete-orphan" + lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship( + back_populates="user" ) - achievements = relationship( - "LazerUserAchievement", - back_populates="user", - cascade="all, delete-orphan", - overlaps="lazer_achievements", + lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship( + back_populates="user" ) - team_membership = relationship( - "TeamMember", back_populates="user", cascade="all, delete-orphan" - ) - daily_challenge_stats = relationship( - "DailyChallengeStats", - back_populates="user", - uselist=False, - cascade="all, delete-orphan", - ) - rank_history = relationship( - "RankHistory", back_populates="user", cascade="all, delete-orphan" - ) - avatar = relationship( - "UserAvatar", - back_populates="user", - primaryjoin="and_(User.id==UserAvatar.user_id, UserAvatar.is_active==True)", - uselist=False, - ) - active_banners = relationship( - "LazerUserBanners", # 原定义指向LazerUserBanners,实际应为UserAvatar - back_populates="user", - primaryjoin=( - "and_(User.id==LazerUserBanners.user_id, LazerUserBanners.is_active==True)" - ), - uselist=False, - ) - lazer_badges = relationship( - "LazerUserBadge", back_populates="user", cascade="all, delete-orphan" - ) - lazer_monthly_playcounts = relationship( - "LazerUserMonthlyPlaycounts", - back_populates="user", - cascade="all, delete-orphan", - ) - lazer_previous_usernames = relationship( - "LazerUserPreviousUsername", back_populates="user", cascade="all, delete-orphan" - ) - lazer_replays_watched = relationship( - "LazerUserReplaysWatched", back_populates="user", cascade="all, delete-orphan" - ) - # ============================================ @@ -155,265 +101,303 @@ class User(Base): # ============================================ -class LazerUserProfile(Base): - __tablename__ = "lazer_user_profiles" +class LazerUserProfile(SQLModel, table=True): + __tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType] - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) + user_id: int = Field(foreign_key="users.id", primary_key=True) # 基本状态字段 - is_active = Column(Boolean, default=True) - is_bot = Column(Boolean, default=False) - is_deleted = Column(Boolean, default=False) - is_online = Column(Boolean, default=True) - is_supporter = Column(Boolean, default=False) - is_restricted = Column(Boolean, default=False) - session_verified = Column(Boolean, default=False) - has_supported = Column(Boolean, default=False) - pm_friends_only = Column(Boolean, default=False) + is_active: bool = Field(default=True) + is_bot: bool = Field(default=False) + is_deleted: bool = Field(default=False) + is_online: bool = Field(default=True) + is_supporter: bool = Field(default=False) + is_restricted: bool = Field(default=False) + session_verified: bool = Field(default=False) + has_supported: bool = Field(default=False) + pm_friends_only: bool = Field(default=False) # 基本资料字段 - default_group = Column(String(50), default="default") - last_visit = Column(DateTime) - join_date = Column(DateTime) - profile_colour = Column(String(7)) - profile_hue = Column(Integer) + default_group: str = Field(default="default", max_length=50) + last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime)) + join_date: datetime | None = Field(default=None, sa_column=Column(DateTime)) + profile_colour: str | None = Field(default=None, max_length=7) + profile_hue: int | None = Field(default=None) # 社交媒体和个人资料字段 - avatar_url = Column(String(500)) - cover_url = Column(String(500)) - discord = Column(String(100)) - twitter = Column(String(100)) - website = Column(String(500)) - title = Column(String(100)) - title_url = Column(String(500)) - interests = Column(Text) - location = Column(String(100)) + avatar_url: str | None = Field(default=None, max_length=500) + cover_url: str | None = Field(default=None, max_length=500) + discord: str | None = Field(default=None, max_length=100) + twitter: str | None = Field(default=None, max_length=100) + website: str | None = Field(default=None, max_length=500) + title: str | None = Field(default=None, max_length=100) + title_url: str | None = Field(default=None, max_length=500) + interests: str | None = Field(default=None, sa_column=Column(Text)) + location: str | None = Field(default=None, max_length=100) - occupation = None # 职业字段,默认为 None + occupation: str | None = Field(default=None) # 职业字段,默认为 None # 游戏相关字段 - playmode = Column(String(10), default="osu") - support_level = Column(Integer, default=0) - max_blocks = Column(Integer, default=100) - max_friends = Column(Integer, default=500) - post_count = Column(Integer, default=0) + playmode: str = Field(default="osu", max_length=10) + support_level: int = Field(default=0) + max_blocks: int = Field(default=100) + max_friends: int = Field(default=500) + post_count: int = Field(default=0) # 页面内容 - page_html = Column(Text) - page_raw = Column(Text) + page_html: str | None = Field(default=None, sa_column=Column(Text)) + page_raw: str | None = Field(default=None, sa_column=Column(Text)) - # created_at = Column(DateTime, default=datetime.utcnow) - # updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + profile_order: str = Field( + default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu" + ) # 关联关系 - user = relationship("User", back_populates="lazer_profile") + user: "User" = Relationship(back_populates="lazer_profile") -class LazerUserProfileSections(Base): - __tablename__ = "lazer_user_profile_sections" +class LazerUserProfileSections(SQLModel, table=True): + __tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - section_name = Column(VARCHAR(50), nullable=False) - display_order = Column(Integer) + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id") + section_name: str = Field(sa_column=Column(VARCHAR(50))) + display_order: int | None = Field(default=None) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) - user = relationship("User", back_populates="lazer_profile_sections") + user: "User" = Relationship(back_populates="lazer_profile_sections") -class LazerUserCountry(Base): - __tablename__ = "lazer_user_countries" +class LazerUserCountry(SQLModel, table=True): + __tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType] - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - code = Column(String(2), nullable=False) - name = Column(String(100), nullable=False) + user_id: int = Field(foreign_key="users.id", primary_key=True) + code: str = Field(max_length=2) + name: str = Field(max_length=100) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class LazerUserKudosu(Base): - __tablename__ = "lazer_user_kudosu" - - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - available = Column(Integer, default=0) - total = Column(Integer, default=0) - - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class LazerUserCounts(Base): - __tablename__ = "lazer_user_counts" - - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - - # 统计计数字段 - beatmap_playcounts_count = Column(Integer, default=0) - comments_count = Column(Integer, default=0) - favourite_beatmapset_count = Column(Integer, default=0) - follower_count = Column(Integer, default=0) - graveyard_beatmapset_count = Column(Integer, default=0) - guest_beatmapset_count = Column(Integer, default=0) - loved_beatmapset_count = Column(Integer, default=0) - mapping_follower_count = Column(Integer, default=0) - nominated_beatmapset_count = Column(Integer, default=0) - pending_beatmapset_count = Column(Integer, default=0) - ranked_beatmapset_count = Column(Integer, default=0) - ranked_and_approved_beatmapset_count = Column(Integer, default=0) - unranked_beatmapset_count = Column(Integer, default=0) - scores_best_count = Column(Integer, default=0) - scores_first_count = Column(Integer, default=0) - scores_pinned_count = Column(Integer, default=0) - scores_recent_count = Column(Integer, default=0) - - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - -class LazerUserStatistics(Base): - __tablename__ = "lazer_user_statistics" - - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - mode = Column(String(10), nullable=False, default="osu", primary_key=True) - - # 基本命中统计 - count_100 = Column(Integer, default=0) - count_300 = Column(Integer, default=0) - count_50 = Column(Integer, default=0) - count_miss = Column(Integer, default=0) - - # 等级信息 - level_current = Column(Integer, default=1) - level_progress = Column(Integer, default=0) - - # 排名信息 - global_rank = Column(Integer) - global_rank_exp = Column(Integer) - country_rank = Column(Integer) - - # PP 和分数 - pp = Column(DECIMAL(10, 2), default=0.00) - pp_exp = Column(DECIMAL(10, 2), default=0.00) - ranked_score = Column(Integer, default=0) - hit_accuracy = Column(DECIMAL(5, 2), default=0.00) - total_score = Column(Integer, default=0) - total_hits = Column(Integer, default=0) - maximum_combo = Column(Integer, default=0) - - # 游戏统计 - play_count = Column(Integer, default=0) - play_time = Column(Integer, default=0) # 秒 - replays_watched_by_others = Column(Integer, default=0) - is_ranked = Column(Boolean, default=False) - - # 成绩等级计数 - grade_ss = Column(Integer, default=0) - grade_ssh = Column(Integer, default=0) - grade_s = Column(Integer, default=0) - grade_sh = Column(Integer, default=0) - grade_a = Column(Integer, default=0) - - # 最高排名记录 - rank_highest = Column(Integer) - rank_highest_updated_at = Column(DateTime) - - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - # 关联关系 - user = relationship("User", back_populates="lazer_statistics") - - -class LazerUserBanners(Base): - __tablename__ = "lazer_user_tournament_banners" - - id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - tournament_id = Column(Integer, nullable=False) - image_url = Column(VARCHAR(500), nullable=False) - is_active = Column(TINYINT(1)) - - # 修正user关系的back_populates值 - user = relationship( - "User", - back_populates="active_banners", # 改为实际存在的属性名 + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) ) -class LazerUserAchievement(Base): - __tablename__ = "lazer_user_achievements" +class LazerUserKudosu(SQLModel, table=True): + __tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - achievement_id = Column(Integer, nullable=False) - achieved_at = Column(DateTime, default=datetime.utcnow) + user_id: int = Field(foreign_key="users.id", primary_key=True) + available: int = Field(default=0) + total: int = Field(default=0) - # created_at = Column(DateTime, default=datetime.utcnow) - # updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - user = relationship("User", back_populates="lazer_achievements") + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) -class LazerUserBadge(Base): - __tablename__ = "lazer_user_badges" +class LazerUserCounts(SQLModel, table=True): + __tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - badge_id = Column(Integer, nullable=False) - awarded_at = Column(DateTime) - description = Column(Text) - image_url = Column(String(500)) - url = Column(String(500)) + user_id: int = Field(foreign_key="users.id", primary_key=True) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + # 统计计数字段 + beatmap_playcounts_count: int = Field(default=0) + comments_count: int = Field(default=0) + favourite_beatmapset_count: int = Field(default=0) + follower_count: int = Field(default=0) + graveyard_beatmapset_count: int = Field(default=0) + guest_beatmapset_count: int = Field(default=0) + loved_beatmapset_count: int = Field(default=0) + mapping_follower_count: int = Field(default=0) + nominated_beatmapset_count: int = Field(default=0) + pending_beatmapset_count: int = Field(default=0) + ranked_beatmapset_count: int = Field(default=0) + ranked_and_approved_beatmapset_count: int = Field(default=0) + unranked_beatmapset_count: int = Field(default=0) + scores_best_count: int = Field(default=0) + scores_first_count: int = Field(default=0) + scores_pinned_count: int = Field(default=0) + scores_recent_count: int = Field(default=0) - user = relationship("User", back_populates="lazer_badges") + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + + # 关联关系 + user: "User" = Relationship(back_populates="lazer_counts") -class LazerUserMonthlyPlaycounts(Base): - __tablename__ = "lazer_user_monthly_playcounts" +class LazerUserStatistics(SQLModel, table=True): + __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - start_date = Column(Date, nullable=False) - play_count = Column(Integer, default=0) + user_id: int = Field(foreign_key="users.id", primary_key=True) + mode: str = Field(default="osu", max_length=10, primary_key=True) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + # 基本命中统计 + count_100: int = Field(default=0) + count_300: int = Field(default=0) + count_50: int = Field(default=0) + count_miss: int = Field(default=0) - user = relationship("User", back_populates="lazer_monthly_playcounts") + # 等级信息 + level_current: int = Field(default=1) + level_progress: int = Field(default=0) + + # 排名信息 + global_rank: int | None = Field(default=None) + global_rank_exp: int | None = Field(default=None) + country_rank: int | None = Field(default=None) + + # PP 和分数 + pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) + pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) + ranked_score: int = Field(default=0) + hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2))) + total_score: int = Field(default=0) + total_hits: int = Field(default=0) + maximum_combo: int = Field(default=0) + + # 游戏统计 + play_count: int = Field(default=0) + play_time: int = Field(default=0) # 秒 + replays_watched_by_others: int = Field(default=0) + is_ranked: bool = Field(default=False) + + # 成绩等级计数 + grade_ss: int = Field(default=0) + grade_ssh: int = Field(default=0) + grade_s: int = Field(default=0) + grade_sh: int = Field(default=0) + grade_a: int = Field(default=0) + + # 最高排名记录 + rank_highest: int | None = Field(default=None) + rank_highest_updated_at: datetime | None = Field( + default=None, sa_column=Column(DateTime) + ) + + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + + # 关联关系 + user: "User" = Relationship(back_populates="lazer_statistics") -class LazerUserPreviousUsername(Base): - __tablename__ = "lazer_user_previous_usernames" +class LazerUserBanners(SQLModel, table=True): + __tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - username = Column(String(32), nullable=False) - changed_at = Column(DateTime, nullable=False) + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id") + tournament_id: int + image_url: str = Field(sa_column=Column(VARCHAR(500))) + is_active: bool | None = Field(default=None) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - - user = relationship("User", back_populates="lazer_previous_usernames") + # 修正user关系的back_populates值 + user: "User" = Relationship(back_populates="active_banners") -class LazerUserReplaysWatched(Base): - __tablename__ = "lazer_user_replays_watched" +class LazerUserAchievement(SQLModel, table=True): + __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - start_date = Column(Date, nullable=False) - count = Column(Integer, default=0) + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + achievement_id: int + achieved_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + user: "User" = Relationship(back_populates="lazer_achievements") - user = relationship("User", back_populates="lazer_replays_watched") + +class LazerUserBadge(SQLModel, table=True): + __tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + badge_id: int + awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime)) + description: str | None = Field(default=None, sa_column=Column(Text)) + image_url: str | None = Field(default=None, max_length=500) + url: str | None = Field(default=None, max_length=500) + + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + + user: "User" = Relationship(back_populates="lazer_badges") + + +class LazerUserMonthlyPlaycounts(SQLModel, table=True): + __tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + start_date: datetime = Field(sa_column=Column(Date)) + play_count: int = Field(default=0) + + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + + user: "User" = Relationship(back_populates="lazer_monthly_playcounts") + + +class LazerUserPreviousUsername(SQLModel, table=True): + __tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + username: str = Field(max_length=32) + changed_at: datetime = Field(sa_column=Column(DateTime)) + + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + + user: "User" = Relationship(back_populates="lazer_previous_usernames") + + +class LazerUserReplaysWatched(SQLModel, table=True): + __tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + start_date: datetime = Field(sa_column=Column(Date)) + count: int = Field(default=0) + + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + + user: "User" = Relationship(back_populates="lazer_replays_watched") # ============================================ @@ -421,87 +405,86 @@ class LazerUserReplaysWatched(Base): # ============================================ -class LegacyUserStatistics(Base): - __tablename__ = "user_statistics" +class LegacyUserStatistics(SQLModel, table=True): + __tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - mode = Column(String(10), nullable=False) # osu, taiko, fruits, mania + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + mode: str = Field(max_length=10) # osu, taiko, fruits, mania # 基本统计 - count_100 = Column(Integer, default=0) - count_300 = Column(Integer, default=0) - count_50 = Column(Integer, default=0) - count_miss = Column(Integer, default=0) + count_100: int = Field(default=0) + count_300: int = Field(default=0) + count_50: int = Field(default=0) + count_miss: int = Field(default=0) # 等级信息 - level_current = Column(Integer, default=1) - level_progress = Column(Integer, default=0) + level_current: int = Field(default=1) + level_progress: int = Field(default=0) # 排名信息 - global_rank = Column(Integer) - global_rank_exp = Column(Integer) - country_rank = Column(Integer) + global_rank: int | None = Field(default=None) + global_rank_exp: int | None = Field(default=None) + country_rank: int | None = Field(default=None) # PP 和分数 - pp = Column(Float, default=0.0) - pp_exp = Column(Float, default=0.0) - ranked_score = Column(Integer, default=0) - hit_accuracy = Column(Float, default=0.0) - total_score = Column(Integer, default=0) - total_hits = Column(Integer, default=0) - maximum_combo = Column(Integer, default=0) + pp: float = Field(default=0.0) + pp_exp: float = Field(default=0.0) + ranked_score: int = Field(default=0) + hit_accuracy: float = Field(default=0.0) + total_score: int = Field(default=0) + total_hits: int = Field(default=0) + maximum_combo: int = Field(default=0) # 游戏统计 - play_count = Column(Integer, default=0) - play_time = Column(Integer, default=0) - replays_watched_by_others = Column(Integer, default=0) - is_ranked = Column(Boolean, default=False) + play_count: int = Field(default=0) + play_time: int = Field(default=0) + replays_watched_by_others: int = Field(default=0) + is_ranked: bool = Field(default=False) # 成绩等级计数 - grade_ss = Column(Integer, default=0) - grade_ssh = Column(Integer, default=0) - grade_s = Column(Integer, default=0) - grade_sh = Column(Integer, default=0) - grade_a = Column(Integer, default=0) + grade_ss: int = Field(default=0) + grade_ssh: int = Field(default=0) + grade_s: int = Field(default=0) + grade_sh: int = Field(default=0) + grade_a: int = Field(default=0) # 最高排名记录 - rank_highest = Column(Integer) - rank_highest_updated_at = Column(DateTime) + rank_highest: int | None = Field(default=None) + rank_highest_updated_at: datetime | None = Field( + default=None, sa_column=Column(DateTime) + ) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) # 关联关系 - user = relationship("User", back_populates="statistics") + user: "User" = Relationship(back_populates="statistics") -class LegacyOAuthToken(Base): - __tablename__ = "legacy_oauth_tokens" +class LegacyOAuthToken(SQLModel, table=True): + __tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - access_token = Column(String(255), nullable=False, index=True) - refresh_token = Column(String(255), nullable=False, index=True) - expires_at = Column(DateTime, nullable=False) - created_at = Column(DateTime, default=datetime.utcnow) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - previous_usernames = Column(JSON, default=list) - replays_watched_counts = Column(JSON, default=list) + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id") + access_token: str = Field(max_length=255, index=True) + refresh_token: str = Field(max_length=255, index=True) + expires_at: datetime = Field(sa_column=Column(DateTime)) + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) + previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON)) + replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON)) # 用户关系 - user = relationship("User") - - -# class UserAchievement(Base): -# __tablename__ = "lazer_user_achievements" - -# id = Column(Integer, primary_key=True, index=True) -# user_id = Column(Integer, ForeignKey("users.id"), nullable=False) -# achievement_id = Column(Integer, nullable=False) -# achieved_at = Column(DateTime, default=datetime.utcnow) - -# user = relationship("User", back_populates="achievements") + user: "User" = Relationship() # 类型转换用的 UserAchievement(不是 SQLAlchemy 模型) @@ -511,91 +494,99 @@ class UserAchievement: achievement_id: int -class Team(Base): - __tablename__ = "teams" +class Team(SQLModel, table=True): + __tablename__ = "teams" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - name = Column(String(100), nullable=False) - short_name = Column(String(10), nullable=False) - flag_url = Column(String(500)) - created_at = Column(DateTime, default=datetime.utcnow) - - members = relationship( - "TeamMember", back_populates="team", cascade="all, delete-orphan" + id: int | None = Field(default=None, primary_key=True, index=True) + name: str = Field(max_length=100) + short_name: str = Field(max_length=10) + flag_url: str | None = Field(default=None, max_length=500) + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - -class TeamMember(Base): - __tablename__ = "team_members" - - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - team_id = Column(Integer, ForeignKey("teams.id"), nullable=False) - joined_at = Column(DateTime, default=datetime.utcnow) - - user = relationship("User", back_populates="team_membership") - team = relationship("Team", back_populates="members") + members: list["TeamMember"] = Relationship(back_populates="team") -class DailyChallengeStats(Base): - __tablename__ = "daily_challenge_stats" +class TeamMember(SQLModel, table=True): + __tablename__ = "team_members" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False, unique=True) + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + team_id: int = Field(foreign_key="teams.id") + joined_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) - daily_streak_best = Column(Integer, default=0) - daily_streak_current = Column(Integer, default=0) - last_update = Column(DateTime) - last_weekly_streak = Column(DateTime) - playcount = Column(Integer, default=0) - top_10p_placements = Column(Integer, default=0) - top_50p_placements = Column(Integer, default=0) - weekly_streak_best = Column(Integer, default=0) - weekly_streak_current = Column(Integer, default=0) - - user = relationship("User", back_populates="daily_challenge_stats") + user: "User" = Relationship(back_populates="team_membership") + team: "Team" = Relationship(back_populates="members") -class RankHistory(Base): - __tablename__ = "rank_history" +class DailyChallengeStats(SQLModel, table=True): + __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - mode = Column(String(10), nullable=False) - rank_data = Column(JSON, nullable=False) # Array of ranks - date_recorded = Column(DateTime, default=datetime.utcnow) + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id", unique=True) - user = relationship("User", back_populates="rank_history") + daily_streak_best: int = Field(default=0) + daily_streak_current: int = Field(default=0) + last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) + last_weekly_streak: datetime | None = Field( + default=None, sa_column=Column(DateTime) + ) + playcount: int = Field(default=0) + top_10p_placements: int = Field(default=0) + top_50p_placements: int = Field(default=0) + weekly_streak_best: int = Field(default=0) + weekly_streak_current: int = Field(default=0) + + user: "User" = Relationship(back_populates="daily_challenge_stats") -class OAuthToken(Base): - __tablename__ = "oauth_tokens" +class RankHistory(SQLModel, table=True): + __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - access_token = Column(String(500), unique=True, nullable=False) - refresh_token = Column(String(500), unique=True, nullable=False) - token_type = Column(String(20), default="Bearer") - scope = Column(String(100), default="*") - expires_at = Column(DateTime, nullable=False) - created_at = Column(DateTime, default=datetime.utcnow) + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + mode: str = Field(max_length=10) + rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks + date_recorded: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) - user = relationship("User") + user: "User" = Relationship(back_populates="rank_history") -class UserAvatar(Base): - __tablename__ = "user_avatars" +class OAuthToken(SQLModel, table=True): + __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - filename = Column(String(255), nullable=False) - original_filename = Column(String(255), nullable=False) - file_size = Column(Integer, nullable=False) - mime_type = Column(String(100), nullable=False) - is_active = Column(Boolean, default=True) - created_at = Column(Integer, default=lambda: int(datetime.now().timestamp())) - updated_at = Column(Integer, default=lambda: int(datetime.now().timestamp())) - r2_original_url = Column(String(500)) - r2_game_url = Column(String(500)) + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + access_token: str = Field(max_length=500, unique=True) + refresh_token: str = Field(max_length=500, unique=True) + token_type: str = Field(default="Bearer", max_length=20) + scope: str = Field(default="*", max_length=100) + expires_at: datetime = Field(sa_column=Column(DateTime)) + created_at: datetime = Field( + default_factory=datetime.utcnow, sa_column=Column(DateTime) + ) - user = relationship("User", back_populates="avatar") + user: "User" = Relationship() + + +class UserAvatar(SQLModel, table=True): + __tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field(foreign_key="users.id") + filename: str = Field(max_length=255) + original_filename: str = Field(max_length=255) + file_size: int + mime_type: str = Field(max_length=100) + is_active: bool = Field(default=True) + created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) + updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) + r2_original_url: str | None = Field(default=None, max_length=500) + r2_game_url: str | None = Field(default=None, max_length=500) + + user: "User" = Relationship(back_populates="avatar") diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 2eb3fce..c297c1c 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -1,7 +1,6 @@ from __future__ import annotations -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlmodel import Session, create_engine try: import redis @@ -11,7 +10,6 @@ from app.config import settings # 数据库引擎 engine = create_engine(settings.DATABASE_URL) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # Redis 连接 if redis: @@ -22,11 +20,8 @@ else: # 数据库依赖 def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() + with Session(engine) as session: + yield session # Redis 依赖 diff --git a/app/dependencies/user.py b/app/dependencies/user.py index b6c68e5..580fbe0 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -1,14 +1,16 @@ -from fastapi import Depends, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from sqlalchemy.orm import Session +from __future__ import annotations from app.auth import get_token_by_access_token - -from .database import get_db from app.database import ( User as DBUser, ) +from .database import get_db + +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlmodel import Session, select + security = HTTPBearer() @@ -29,5 +31,5 @@ async def get_current_user_by_token(token: str, db: Session) -> DBUser | None: token_record = get_token_by_access_token(db, token) if not token_record: return None - user = db.query(DBUser).filter(DBUser.id == token_record.user_id).first() + user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first() return user diff --git a/app/models/user.py b/app/models/user.py index 6bc6734..3b7d26b 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -3,10 +3,14 @@ from __future__ import annotations from datetime import datetime from enum import Enum +from app.database import ( + LazerUserAchievement, + Team as Team, +) + from .score import GameMode from pydantic import BaseModel -from app.database import LazerUserAchievement # 添加数据库模型导入 class PlayStyle(str, Enum): @@ -110,13 +114,6 @@ class DailyChallengeStats(BaseModel): weekly_streak_current: int = 0 -class Team(BaseModel): - flag_url: str - id: int - name: str - short_name: str - - class Page(BaseModel): html: str = "" raw: str = "" diff --git a/app/router/auth.py b/app/router/auth.py index 5faf1dc..838a75e 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -14,7 +14,7 @@ from app.dependencies import get_db from app.models.oauth import TokenResponse from fastapi import APIRouter, Depends, Form, HTTPException -from sqlalchemy.orm import Session +from sqlmodel import Session router = APIRouter(tags=["osu! OAuth 认证"]) @@ -92,7 +92,7 @@ async def oauth_token( new_refresh_token = generate_refresh_token() # 更新令牌 - user_id = int(getattr(token_record, 'user_id')) + user_id = int(getattr(token_record, "user_id")) store_token( db, user_id, diff --git a/app/utils.py b/app/utils.py index 19ebe81..89c5d86 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,8 +1,13 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime -from app.database import User as DBUser +from app.database import ( + LazerUserCounts, + LazerUserProfile, + LazerUserStatistics, + User as DBUser, +) from app.models.user import ( Country, Cover, @@ -14,7 +19,6 @@ from app.models.user import ( RankHighest, RankHistory, Statistics, - Team, User, UserAchievement, ) @@ -37,25 +41,28 @@ def convert_db_user_to_api_user( profile = db_user.lazer_profile if not profile: # 如果没有 lazer 资料,使用默认值 - profile = create_default_profile(db_user) + profile = LazerUserProfile( + user_id=user_id, + ) - # 获取 Lazer 用户计数 + # 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系 + lzrcnt = db_user.lazer_counts - lzrcnt = db_user.lazer_statistics if not lzrcnt: # 如果没有 lazer 计数,使用默认值 - lzrcnt = create_default_counts() + lzrcnt = LazerUserCounts(user_id=user_id) # 获取指定模式的统计信息 user_stats = None - for stat in db_user.lazer_statistics: - if stat.mode == ruleset: - user_stats = stat - break + if db_user.lazer_statistics: + for stat in db_user.lazer_statistics: + if stat.mode == ruleset: + user_stats = stat + break if not user_stats: # 如果没有找到指定模式的统计,创建默认统计 - user_stats = create_default_lazer_statistics(ruleset) + user_stats = LazerUserStatistics(user_id=user_id) # 获取国家信息 country_code = db_user.country_code if db_user.country_code is not None else "XX" @@ -66,7 +73,7 @@ def convert_db_user_to_api_user( kudosu = Kudosu(available=0, total=0) # 获取计数信息 - counts = create_default_counts() + counts = LazerUserCounts(user_id=user_id) # 转换统计信息 statistics = Statistics( @@ -199,12 +206,7 @@ def convert_db_user_to_api_user( team = None if db_user.team_membership: team_member = db_user.team_membership[0] # 假设用户只属于一个团队 - team = Team( - flag_url=team_member.team.flag_url or "", - id=team_member.team.id, - name=team_member.team.name, - short_name=team_member.team.short_name, - ) + team = team_member.team # 创建用户对象 # 从db_user获取基本字段值 @@ -229,27 +231,25 @@ def convert_db_user_to_api_user( avatar_url = str(db_user.avatar.r2_original_url) # 如果还是没有找到,通过查询获取 - if db_session and avatar_url is None: - try: - # 导入UserAvatar模型 - from app.database import UserAvatar + # if db_session and avatar_url is None: + # try: + # # 导入UserAvatar模型 - # 尝试查找用户的头像记录 - avatar_record = ( - db_session.query(UserAvatar) - .filter_by(user_id=user_id, is_active=True) - .first() - ) - if avatar_record is not None: - if avatar_record.r2_game_url is not None: - # 优先使用游戏用的头像URL - avatar_url = str(avatar_record.r2_game_url) - elif avatar_record.r2_original_url is not None: - # 其次使用原始头像URL - avatar_url = str(avatar_record.r2_original_url) - except Exception as e: - print(f"获取用户头像时出错: {e}") - print(f"最终头像URL: {avatar_url}") + # # 尝试查找用户的头像记录 + # statement = select(UserAvatar).where( + # UserAvatar.user_id == user_id, UserAvatar.is_active == True + # ) + # avatar_record = db_session.exec(statement).first() + # if avatar_record is not None: + # if avatar_record.r2_game_url is not None: + # # 优先使用游戏用的头像URL + # avatar_url = str(avatar_record.r2_game_url) + # elif avatar_record.r2_original_url is not None: + # # 其次使用原始头像URL + # avatar_url = str(avatar_record.r2_original_url) + # except Exception as e: + # print(f"获取用户头像时出错: {e}") + # print(f"最终头像URL: {avatar_url}") # 如果仍然没有找到头像URL,则使用默认URL if avatar_url is None: avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1" @@ -265,15 +265,12 @@ def convert_db_user_to_api_user( "kudosu", ] if profile and profile.profile_order: - profile_order = profile.profile_order + profile_order = profile.profile_order.split(",") # 在convert_db_user_to_api_user函数中添加active_tournament_banners处理 active_tournament_banners = [] - if ( - hasattr(db_user, "lazer_tournament_banners") - and db_user.lazer_tournament_banners - ): - for banner in db_user.lazer_tournament_banners: + if db_user.active_banners: + for banner in db_user.active_banners: active_tournament_banners.append( { "tournament_id": banner.tournament_id, @@ -284,7 +281,7 @@ def convert_db_user_to_api_user( # 在convert_db_user_to_api_user函数中添加badges处理 badges = [] - if hasattr(db_user, "lazer_badges") and db_user.lazer_badges: + if db_user.lazer_badges: for badge in db_user.lazer_badges: badges.append( { @@ -298,10 +295,7 @@ def convert_db_user_to_api_user( # 在convert_db_user_to_api_user函数中添加monthly_playcounts处理 monthly_playcounts = [] - if ( - hasattr(db_user, "lazer_monthly_playcounts") - and db_user.lazer_monthly_playcounts - ): + if db_user.lazer_monthly_playcounts: for playcount in db_user.lazer_monthly_playcounts: monthly_playcounts.append( { @@ -314,10 +308,7 @@ def convert_db_user_to_api_user( # 在convert_db_user_to_api_user函数中添加previous_usernames处理 previous_usernames = [] - if ( - hasattr(db_user, "lazer_previous_usernames") - and db_user.lazer_previous_usernames - ): + if db_user.lazer_previous_usernames: for username in db_user.lazer_previous_usernames: previous_usernames.append( { @@ -348,22 +339,22 @@ def convert_db_user_to_api_user( avatar_url=avatar_url, country_code=str(country_code), default_group=profile.default_group if profile else "default", - is_active=profile.is_active if profile else True, - is_bot=profile.is_bot if profile else False, - is_deleted=profile.is_deleted if profile else False, - is_online=profile.is_online if profile else True, - is_supporter=profile.is_supporter if profile else False, - is_restricted=profile.is_restricted if profile else False, - last_visit=db_user.last_visit if db_user.last_visit else None, - pm_friends_only=profile.pm_friends_only if profile else False, - profile_colour=profile.profile_colour if profile else None, + is_active=profile.is_active, + is_bot=profile.is_bot, + is_deleted=profile.is_deleted, + is_online=profile.is_online, + is_supporter=profile.is_supporter, + is_restricted=profile.is_restricted, + last_visit=db_user.last_visit, + pm_friends_only=profile.pm_friends_only, + profile_colour=profile.profile_colour, cover_url=profile.cover_url if profile and profile.cover_url else "https://assets.ppy.sh/user-profile-covers/default.jpeg", discord=profile.discord if profile else None, has_supported=profile.has_supported if profile else False, interests=profile.interests if profile else None, - join_date=profile.join_date, + join_date=profile.join_date if profile.join_date else datetime.now(UTC), location=profile.location if profile else None, max_blocks=profile.max_blocks if profile and profile.max_blocks else 100, max_friends=profile.max_friends if profile and profile.max_friends else 500, @@ -408,7 +399,7 @@ def convert_db_user_to_api_user( daily_challenge_user_stats=None, groups=[], monthly_playcounts=monthly_playcounts, - page=Page(html=profile.page_html, raw=profile.page_raw) + page=Page(html=profile.page_html or "", raw=profile.page_raw or "") if profile.page_html or profile.page_raw else Page(), previous_usernames=previous_usernames, @@ -439,164 +430,3 @@ def get_country_name(country_code: str) -> str: # 可以添加更多国家 } return country_names.get(country_code, "Unknown") - - -def create_default_profile(db_user: DBUser): - """创建默认的用户资料""" - - # 完善 MockProfile 类定义 - class MockProfile: - def __init__(self): - self.is_active = True - self.is_bot = False - self.is_deleted = False - self.is_online = True - self.is_supporter = False - self.is_restricted = False - self.session_verified = False - self.has_supported = False - self.pm_friends_only = False - self.default_group = "default" - self.last_visit = None - self.join_date = db_user.join_date if db_user else datetime.utcnow() - self.profile_colour = None - self.profile_hue = None - self.avatar_url = None - self.cover_url = None - self.discord = None - self.twitter = None - self.website = None - self.title = None - self.title_url = None - self.interests = None - self.location = None - self.occupation = None - self.playmode = "osu" - self.support_level = 0 - self.max_blocks = 100 - self.max_friends = 500 - self.post_count = 0 - # 添加profile_order字段 - self.profile_order = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - self.page_html = "" - self.page_raw = "" - # 在MockProfile类中添加active_tournament_banners字段 - self.active_tournament_banners = ( - MockLazerTournamentBanner.create_default_banners() - ) - self.active_tournament_banners = [] # 默认空列表 - - return MockProfile() - - -def create_default_lazer_statistics(mode: str): - """创建默认的 Lazer 统计信息""" - - class MockLazerStatistics: - def __init__(self, mode: str): - self.mode = mode - self.count_100 = 0 - self.count_300 = 0 - self.count_50 = 0 - self.count_miss = 0 - self.level_current = 1 - self.level_progress = 0 - self.global_rank = None - self.global_rank_exp = None - self.pp = 0.0 - self.pp_exp = 0.0 - self.ranked_score = 0 - self.hit_accuracy = 0.0 - self.total_score = 0 - self.total_hits = 0 - self.maximum_combo = 0 - self.play_count = 0 - self.play_time = 0 - self.replays_watched_by_others = 0 - self.is_ranked = False - self.grade_ss = 0 - self.grade_ssh = 0 - self.grade_s = 0 - self.grade_sh = 0 - self.grade_a = 0 - self.country_rank = None - self.rank_highest = None - self.rank_highest_updated_at = None - - return MockLazerStatistics(mode) - - -def create_default_country(country_code: str): - """创建默认的国家信息""" - - class MockCountry: - def __init__(self, code: str): - self.code = code - self.name = get_country_name(code) - - return MockCountry(country_code) - - -def create_default_kudosu(): - """创建默认的 Kudosu 信息""" - - class MockKudosu: - def __init__(self): - self.available = 0 - self.total = 0 - - return MockKudosu() - - -def create_default_counts(): - """创建默认的计数信息""" - - class MockCounts: - def __init__(self): - self.recent_scores_count = None - self.beatmap_playcounts_count = 0 - self.scores_first_count = 0 - self.scores_pinned_count = 0 - self.comments_count = 0 - self.favourite_beatmapset_count = 0 - self.follower_count = 0 - self.graveyard_beatmapset_count = 0 - self.guest_beatmapset_count = 0 - self.loved_beatmapset_count = 0 - self.mapping_follower_count = 0 - self.nominated_beatmapset_count = 0 - self.pending_beatmapset_count = 0 - self.ranked_beatmapset_count = 0 - self.ranked_and_approved_beatmapset_count = 0 - self.unranked_beatmapset_count = 0 - self.scores_best_count = 0 - self.scores_first_count = 0 - self.scores_pinned_count = 0 - self.scores_recent_count = 0 - - return MockCounts() - - -class MockLazerTournamentBanner: - def __init__(self, tournament_id: int, image_url: str, is_active: bool = True): - self.tournament_id = tournament_id - self.image_url = image_url - self.is_active = is_active - - @staticmethod - def create_default_banners(): - """创建默认的锦标赛横幅配置""" - return [ - MockLazerTournamentBanner(1, "https://example.com/banner1.jpg", True), - MockLazerTournamentBanner(2, "https://example.com/banner2.jpg", False), - ] - - diff --git a/create_sample_data.py b/create_sample_data.py index d2d7ba8..5758cd8 100644 --- a/create_sample_data.py +++ b/create_sample_data.py @@ -10,33 +10,36 @@ import time from app.auth import get_password_hash from app.database import ( - Base, DailyChallengeStats, + LazerUserAchievement, LazerUserStatistics, RankHistory, User, - UserAchievement, ) from app.dependencies.database import engine, get_db +from sqlmodel import SQLModel + # 创建所有表 -Base.metadata.create_all(bind=engine) +SQLModel.metadata.create_all(bind=engine) def create_sample_user(): """创建示例用户数据""" - db = next(get_db()) + with next(get_db()) as db: + # 检查用户是否已存在 + from sqlmodel import select - # 检查用户是否已存在 - existing_user = db.query(User).filter(User.name == "Googujiang").first() - if existing_user: - print("示例用户已存在,跳过创建") - return existing_user + statement = select(User).where(User.name == "Googujiang") + existing_user = db.exec(statement).first() + if existing_user: + print("示例用户已存在,跳过创建") + return existing_user - # 当前时间戳 - current_timestamp = int(time.time()) - join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) - last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) + # 当前时间戳 + current_timestamp = int(time.time()) + join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) + last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) # 创建用户 user = User( @@ -150,6 +153,10 @@ def create_sample_user(): db.commit() db.refresh(user) + # 确保用户ID存在 + if user.id is None: + raise ValueError("User ID is None after saving to database") + # 创建 osu! 模式统计 osu_stats = LazerUserStatistics( user_id=user.id, @@ -370,28 +377,28 @@ def create_sample_user(): # 创建一些成就 achievements = [ - UserAchievement( - # user_id=user.id, + LazerUserAchievement( + user_id=user.id, achievement_id=336, achieved_at=datetime(2025, 6, 21, 19, 6, 32), ), - UserAchievement( - # user_id=user.id, + LazerUserAchievement( + user_id=user.id, achievement_id=319, achieved_at=datetime(2025, 6, 1, 0, 52, 0), ), - UserAchievement( - # user_id=user.id, + LazerUserAchievement( + user_id=user.id, achievement_id=222, achieved_at=datetime(2025, 5, 28, 12, 24, 37), ), - UserAchievement( - # user_id=user.id, + LazerUserAchievement( + user_id=user.id, achievement_id=38, achieved_at=datetime(2024, 7, 5, 15, 43, 23), ), - UserAchievement( - # user_id=user.id, + LazerUserAchievement( + user_id=user.id, achievement_id=67, achieved_at=datetime(2024, 6, 24, 5, 6, 44), ), diff --git a/pyproject.toml b/pyproject.toml index 6e5c1cb..050918b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "python-multipart>=0.0.6", "redis>=5.0.1", "sqlalchemy>=2.0.23", + "sqlmodel>=0.0.24", "uvicorn[standard]>=0.24.0", ] diff --git a/test_lazer.py b/test_lazer.py index 8bdae16..31ebb3c 100644 --- a/test_lazer.py +++ b/test_lazer.py @@ -12,9 +12,11 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(__file__))) from app.database import User -from app.dependencies import get_db +from app.dependencies.database import get_db from app.utils import convert_db_user_to_api_user +from sqlmodel import select + def test_lazer_tables(): """测试 lazer 表的基本功能""" @@ -26,7 +28,8 @@ def test_lazer_tables(): try: # 测试查询用户 - user = db.query(User).first() + statement = select(User) + user = db.exec(statement).first() if not user: print("❌ 没有找到用户,请先同步数据") return False @@ -83,7 +86,8 @@ def test_authentication(): try: # 尝试认证第一个用户 - user = db.query(User).first() + statement = select(User) + user = db.exec(statement).first() if not user: print("❌ 没有用户进行认证测试") return False diff --git a/uv.lock b/uv.lock index 12d956d..580869c 100644 --- a/uv.lock +++ b/uv.lock @@ -475,6 +475,7 @@ dependencies = [ { name = "python-multipart" }, { name = "redis" }, { name = "sqlalchemy" }, + { name = "sqlmodel" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -499,6 +500,7 @@ requires-dist = [ { name = "python-multipart", specifier = ">=0.0.6" }, { name = "redis", specifier = ">=5.0.1" }, { name = "sqlalchemy", specifier = ">=2.0.23" }, + { name = "sqlmodel", specifier = ">=0.0.24" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.24.0" }, ] @@ -810,6 +812,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" }, ] +[[package]] +name = "sqlmodel" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/4b/c2ad0496f5bdc6073d9b4cef52be9c04f2b37a5773441cc6600b1857648b/sqlmodel-0.0.24.tar.gz", hash = "sha256:cc5c7613c1a5533c9c7867e1aab2fd489a76c9e8a061984da11b4e613c182423", size = 116780, upload-time = "2025-03-07T05:43:32.887Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/91/484cd2d05569892b7fef7f5ceab3bc89fb0f8a8c0cde1030d383dbc5449c/sqlmodel-0.0.24-py3-none-any.whl", hash = "sha256:6778852f09370908985b667d6a3ab92910d0d5ec88adcaf23dbc242715ff7193", size = 28622, upload-time = "2025-03-07T05:43:30.37Z" }, +] + [[package]] name = "starlette" version = "0.47.2"