From 585cb9d98aac6a99767c2b3252bd8282dabb91af Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 26 Jul 2025 12:05:54 +0800 Subject: [PATCH] fix(database): fix score database --- app/database/__init__.py | 2 -- app/database/auth.py | 1 - app/database/beatmap.py | 1 - app/database/beatmapset.py | 2 -- app/database/legacy.py | 3 +- app/database/score.py | 33 +++++++++------------ app/database/team.py | 3 +- app/database/user.py | 59 ++++++++++++++++++------------------- app/models/score.py | 8 ++--- app/models/spectator_hub.py | 5 ++-- app/router/beatmap.py | 18 ++++++----- create_sample_data.py | 57 ++++++++++++++++------------------- main.py | 3 +- pyproject.toml | 2 ++ 14 files changed, 90 insertions(+), 107 deletions(-) diff --git a/app/database/__init__.py b/app/database/__init__.py index baf7677..a0cdc2a 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from .auth import OAuthToken from .beatmap import ( Beatmap as Beatmap, diff --git a/app/database/auth.py b/app/database/auth.py index e00debe..8e9032b 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -1,4 +1,3 @@ -# ruff: noqa: I002 from datetime import datetime from typing import TYPE_CHECKING diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 4c0155e..b959770 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -1,4 +1,3 @@ -# ruff: noqa: I002 from datetime import datetime from app.models.beatmap import BeatmapRankStatus diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 781fa9f..4141212 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -1,5 +1,3 @@ -# ruff: noqa: I002 - from datetime import datetime from typing import TYPE_CHECKING, cast diff --git a/app/database/legacy.py b/app/database/legacy.py index 245e267..c0db405 100644 --- a/app/database/legacy.py +++ b/app/database/legacy.py @@ -1,10 +1,9 @@ -# ruff: noqa: I002 from datetime import datetime from typing import TYPE_CHECKING from sqlalchemy import JSON, Column, DateTime -from sqlmodel import Field, Relationship, SQLModel from sqlalchemy.orm import Mapped +from sqlmodel import Field, Relationship, SQLModel if TYPE_CHECKING: from .user import User diff --git a/app/database/score.py b/app/database/score.py index 34eb327..50cd097 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -1,20 +1,16 @@ -# ruff: noqa: I002 - from datetime import datetime import math -from typing import Literal, TYPE_CHECKING, List -from app.models.score import Rank, APIMod, GameMode, MODE_TO_INT +from app.database.user import User +from app.models.score import MODE_TO_INT, APIMod, GameMode, Rank from .beatmap import Beatmap, BeatmapResp -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmapset import BeatmapsetResp from pydantic import BaseModel -from sqlalchemy import Column, DateTime, JSON -from sqlmodel import BigInteger, Field, Relationship, SQLModel, JSON as SQLModeJSON +from sqlalchemy import Column, DateTime +from sqlmodel import JSON, BigInteger, Field, Relationship, SQLModel -if TYPE_CHECKING: - from .user import User class ScoreBase(SQLModel): # 基本字段 @@ -35,7 +31,6 @@ class ScoreBase(SQLModel): preserve: bool = Field(default=True) rank: Rank room_id: int | None = Field(default=None) # multiplayer - ruleset_id: GameMode = Field(index=True) started_at: datetime = Field(sa_column=Column(DateTime)) total_score: int = Field(default=0, sa_column=Column(BigInteger)) type: str @@ -59,8 +54,8 @@ class ScoreStatistics(BaseModel): class Score(ScoreBase, table=True): __tablename__ = "scores" # pyright: ignore[reportAssignmentType] id: int = Field(primary_key=True) - beatmap_id: int = Field(index=True, foreign_key="beatmap.id") - user_id: int = Field(foreign_key="user.id", index=True) + beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") + user_id: int = Field(foreign_key="users.id", index=True) # ScoreStatistics n300: int = Field(exclude=True) n100: int = Field(exclude=True) @@ -70,11 +65,11 @@ class Score(ScoreBase, table=True): nkatu: int = Field(exclude=True) nlarge_tick_miss: int | None = Field(default=None, exclude=True) nslider_tail_hit: int | None = Field(default=None, exclude=True) + gamemode: GameMode = Field(index=True, alias="ruleset_id") # optional - beatmap: "Beatmap" = Relationship(back_populates="scores") - beatmapset: "Beatmapset" = Relationship(back_populates="scores") - # FIXME: user: "User" = Relationship(back_populates="scores") + beatmap: "Beatmap" = Relationship() + user: "User" = Relationship() class ScoreResp(ScoreBase): @@ -84,7 +79,7 @@ class ScoreResp(ScoreBase): legacy_total_score: int = 0 # FIXME processed: bool = False # solo_score weight: float = 0.0 - ruleset_id: int | None + ruleset_id: int | None = None beatmap: BeatmapResp | None = None beatmapset: BeatmapsetResp | None = None # FIXME: user: APIUser | None = None @@ -92,12 +87,12 @@ class ScoreResp(ScoreBase): @classmethod def from_db(cls, score: Score) -> "ScoreResp": - s = cls.model_validate(score) + s = cls.model_validate(score.model_dump()) s.beatmap = BeatmapResp.from_db(score.beatmap) s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo - s.ruleset_id=MODE_TO_INT[score.ruleset_id] + s.ruleset_id = MODE_TO_INT[score.gamemode] if score.best_id: # https://osu.ppy.sh/wiki/Performance_points/Weighting_system s.weight = math.pow(0.95, score.best_id) @@ -111,4 +106,4 @@ class ScoreResp(ScoreBase): count_large_tick_miss=score.nlarge_tick_miss, count_slider_tail_hit=score.nslider_tail_hit, ) - return s \ No newline at end of file + return s diff --git a/app/database/team.py b/app/database/team.py index e7e277d..5dabf71 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -1,4 +1,3 @@ -# ruff: noqa: I002 from datetime import datetime from typing import TYPE_CHECKING @@ -35,4 +34,4 @@ class TeamMember(SQLModel, table=True): ) user: Mapped["User"] = Relationship(back_populates="team_membership") - team: Mapped["Team"] = Relationship(back_populates="members") \ No newline at end of file + team: Mapped["Team"] = Relationship(back_populates="members") diff --git a/app/database/user.py b/app/database/user.py index 160f27b..8b6fe02 100644 --- a/app/database/user.py +++ b/app/database/user.py @@ -1,6 +1,3 @@ -# ruff: noqa: I002 -from __future__ import annotations - from dataclasses import dataclass from datetime import datetime from typing import Optional @@ -10,7 +7,6 @@ from .team import TeamMember from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text from sqlalchemy.dialects.mysql import VARCHAR -from sqlalchemy.orm import Mapped from sqlmodel import BigInteger, Field, Relationship, SQLModel @@ -70,34 +66,35 @@ class User(SQLModel, table=True): return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None # 关联关系 - lazer_profile: Mapped[Optional["LazerUserProfile"]] = Relationship(back_populates="user") - lazer_statistics: Mapped[list["LazerUserStatistics"]] = Relationship(back_populates="user") - lazer_counts: Mapped[Optional["LazerUserCounts"]] = Relationship(back_populates="user") - lazer_achievements: Mapped[list["LazerUserAchievement"]] = Relationship( + lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user") + lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user") + lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user") + lazer_achievements: list["LazerUserAchievement"] = Relationship( back_populates="user" ) - lazer_profile_sections: Mapped[list["LazerUserProfileSections"]] = Relationship( + lazer_profile_sections: list["LazerUserProfileSections"] = Relationship( back_populates="user" ) statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user") - team_membership: Mapped[list["TeamMember"]] = Relationship(back_populates="user") - daily_challenge_stats: Mapped[Optional["DailyChallengeStats"]] = Relationship( + team_membership: list["TeamMember"] = Relationship(back_populates="user") + daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship( back_populates="user" ) - rank_history: Mapped[list["RankHistory"]] = Relationship(back_populates="user") - avatar: Mapped[Optional["UserAvatar"]] = Relationship(back_populates="user") - active_banners: Mapped[list["LazerUserBanners"]] = Relationship(back_populates="user") - lazer_badges: Mapped[list["LazerUserBadge"]] = Relationship(back_populates="user") - lazer_monthly_playcounts: Mapped[list["LazerUserMonthlyPlaycounts"]] = Relationship( + rank_history: list["RankHistory"] = Relationship(back_populates="user") + avatar: Optional["UserAvatar"] = Relationship(back_populates="user") + active_banners: list["LazerUserBanners"] = Relationship(back_populates="user") + lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user") + lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship( back_populates="user" ) - lazer_previous_usernames: Mapped[list["LazerUserPreviousUsername"]] = Relationship( + lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship( back_populates="user" ) - lazer_replays_watched: Mapped[list["LazerUserReplaysWatched"]] = Relationship( + lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship( back_populates="user" ) + # ============================================ # Lazer API 专用表模型 # ============================================ @@ -155,7 +152,7 @@ class LazerUserProfile(SQLModel, table=True): ) # 关联关系 - user: Mapped["User"] = Relationship(back_populates="lazer_profile") + user: "User" = Relationship(back_populates="lazer_profile") class LazerUserProfileSections(SQLModel, table=True): @@ -173,7 +170,7 @@ class LazerUserProfileSections(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="lazer_profile_sections") + user: "User" = Relationship(back_populates="lazer_profile_sections") class LazerUserCountry(SQLModel, table=True): @@ -238,7 +235,7 @@ class LazerUserCounts(SQLModel, table=True): ) # 关联关系 - user: Mapped["User"] = Relationship(back_populates="lazer_counts") + user: "User" = Relationship(back_populates="lazer_counts") class LazerUserStatistics(SQLModel, table=True): @@ -298,7 +295,7 @@ class LazerUserStatistics(SQLModel, table=True): ) # 关联关系 - user: Mapped["User"] = Relationship(back_populates="lazer_statistics") + user: "User" = Relationship(back_populates="lazer_statistics") class LazerUserBanners(SQLModel, table=True): @@ -311,7 +308,7 @@ class LazerUserBanners(SQLModel, table=True): is_active: bool | None = Field(default=None) # 修正user关系的back_populates值 - user: Mapped["User"] = Relationship(back_populates="active_banners") + user: "User" = Relationship(back_populates="active_banners") class LazerUserAchievement(SQLModel, table=True): @@ -324,7 +321,7 @@ class LazerUserAchievement(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="lazer_achievements") + user: "User" = Relationship(back_populates="lazer_achievements") class LazerUserBadge(SQLModel, table=True): @@ -345,7 +342,7 @@ class LazerUserBadge(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="lazer_badges") + user: "User" = Relationship(back_populates="lazer_badges") class LazerUserMonthlyPlaycounts(SQLModel, table=True): @@ -363,7 +360,7 @@ class LazerUserMonthlyPlaycounts(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="lazer_monthly_playcounts") + user: "User" = Relationship(back_populates="lazer_monthly_playcounts") class LazerUserPreviousUsername(SQLModel, table=True): @@ -381,7 +378,7 @@ class LazerUserPreviousUsername(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="lazer_previous_usernames") + user: "User" = Relationship(back_populates="lazer_previous_usernames") class LazerUserReplaysWatched(SQLModel, table=True): @@ -399,7 +396,7 @@ class LazerUserReplaysWatched(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="lazer_replays_watched") + user: "User" = Relationship(back_populates="lazer_replays_watched") # 类型转换用的 UserAchievement(不是 SQLAlchemy 模型) @@ -427,7 +424,7 @@ class DailyChallengeStats(SQLModel, table=True): weekly_streak_best: int = Field(default=0) weekly_streak_current: int = Field(default=0) - user: Mapped["User"] = Relationship(back_populates="daily_challenge_stats") + user: "User" = Relationship(back_populates="daily_challenge_stats") class RankHistory(SQLModel, table=True): @@ -441,7 +438,7 @@ class RankHistory(SQLModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: Mapped["User"] = Relationship(back_populates="rank_history") + user: "User" = Relationship(back_populates="rank_history") class UserAvatar(SQLModel, table=True): @@ -459,4 +456,4 @@ class UserAvatar(SQLModel, table=True): r2_original_url: str | None = Field(default=None, max_length=500) r2_game_url: str | None = Field(default=None, max_length=500) - user: Mapped["User"] = Relationship(back_populates="avatar") + user: "User" = Relationship(back_populates="avatar") diff --git a/app/models/score.py b/app/models/score.py index 7c91fe0..eb3f590 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -1,9 +1,7 @@ from __future__ import annotations from enum import Enum, IntEnum -from typing import Any - -from pydantic import BaseModel +from typing import Any, TypedDict class GameMode(str, Enum): @@ -34,9 +32,9 @@ class Rank(str, Enum): F = "f" -class APIMod(BaseModel): +class APIMod(TypedDict): acronym: str - # settings: dict[str, Any] = {} + settings: dict[str, Any] # https://github.com/ppy/osu/blob/master/osu.Game/Rulesets/Scoring/HitResult.cs diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index bdf82a9..d9aa296 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -5,7 +5,6 @@ from enum import IntEnum from typing import Any from .score import ( - APIMod as APIModBase, HitResult, ) from .signalr import MessagePackArrayModel @@ -14,7 +13,9 @@ import msgpack from pydantic import Field, field_validator -class APIMod(APIModBase, MessagePackArrayModel): ... +class APIMod(MessagePackArrayModel): + acronym: str + settings: dict[str, Any] = Field(default_factory=dict) class SpectatedUserState(IntEnum): diff --git a/app/router/beatmap.py b/app/router/beatmap.py index aa8864a..4731994 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,11 +5,10 @@ from app.database import ( BeatmapResp, User as DBUser, ) -from app.database.score import Score, ScoreResp, APIMod from app.database.beatmapset import Beatmapset +from app.database.score import Score, ScoreResp from app.dependencies.database import get_db from app.dependencies.user import get_current_user -from typing import List, Optional from .api_router import router @@ -29,8 +28,7 @@ async def get_beatmap( beatmap = ( await db.exec( select(Beatmap) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] + .options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] .where(Beatmap.id == bid) ) ).first() @@ -78,8 +76,8 @@ async def batch_get_beatmaps( class BeatmapScores(BaseModel): - scores: List[ScoreResp] - userScore: Optional[ScoreResp] = None + scores: list[ScoreResp] + userScore: ScoreResp | None = None @router.get( @@ -101,8 +99,7 @@ async def get_beatmapset_scores( all_scores = ( await db.exec( - select(Score) - .where(Score.beatmap_id == beatmap) + select(Score).where(Score.beatmap_id == beatmap) # .where(Score.mods == mods if mods else True) ) ).all() @@ -110,6 +107,11 @@ async def get_beatmapset_scores( user_score = ( await db.exec( select(Score) + .options( + joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] + .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] + .selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] + ) .where(Score.beatmap_id == beatmap) .where(Score.user_id == current_user.id) ) diff --git a/create_sample_data.py b/create_sample_data.py index b91223e..610bff7 100644 --- a/create_sample_data.py +++ b/create_sample_data.py @@ -7,17 +7,18 @@ from __future__ import annotations import asyncio from datetime import datetime +import random from app.auth import get_password_hash from app.database import ( User, ) -from app.database.beatmapset import Beatmapset, BeatmapsetResp -from app.database.beatmap import Beatmap, BeatmapResp +from app.database.beatmap import Beatmap +from app.database.beatmapset import Beatmapset from app.database.score import Score -from app.models.score import GameMode, Rank, APIMod -from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.dependencies.database import create_tables, engine +from app.models.beatmap import BeatmapRankStatus, Genre, Language +from app.models.score import APIMod, GameMode, Rank from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -29,8 +30,8 @@ async def create_sample_user(): async with session.begin(): # 检查用户是否已存在 statement = select(User).where(User.name == "Googujiang") - result = await session.execute(statement) - existing_user = result.scalars().first() + result = await session.exec(statement) + existing_user = result.first() if existing_user: print("示例用户已存在,跳过创建") return existing_user @@ -63,13 +64,6 @@ async def create_sample_user(): ) session.add(user) - await session.commit() - await session.refresh(user) - - # 确保用户ID存在 - if user.id is None: - raise ValueError("User ID is None after saving to database") - print(f"成功创建示例用户: {user.name} (ID: {user.id})") print(f"安全用户名: {user.safe_name}") print(f"邮箱: {user.email}") @@ -77,14 +71,15 @@ async def create_sample_user(): return user -async def create_sample_beatmap_data(user: User): +async def create_sample_beatmap_data(): """创建示例谱面数据""" async with AsyncSession(engine) as session: async with session.begin(): + user_id = random.randint(1, 1000) # 检查谱面集是否已存在 statement = select(Beatmapset).where(Beatmapset.id == 1) - result = await session.execute(statement) - existing_beatmapset = result.scalars().first() + result = await session.exec(statement) + existing_beatmapset = result.first() if existing_beatmapset: print("示例谱面集已存在,跳过创建") return existing_beatmapset @@ -106,7 +101,7 @@ async def create_sample_beatmap_data(user: User): spotlight=False, title="Example Song", title_unicode="Example Song", - user_id=user.id, + user_id=user_id, video=False, availability_info=None, download_disabled=False, @@ -127,7 +122,6 @@ async def create_sample_beatmap_data(user: User): ratings=[], ) session.add(beatmapset) - await session.flush() # 创建谱面 beatmap = Beatmap( @@ -138,7 +132,7 @@ async def create_sample_beatmap_data(user: User): difficulty_rating=5.5, beatmap_status=BeatmapRankStatus.RANKED, total_length=195, - user_id=user.id, + user_id=user_id, version="Example Difficulty", checksum="example_checksum", current_user_playcount=0, @@ -158,33 +152,35 @@ async def create_sample_beatmap_data(user: User): playcount=50, ) session.add(beatmap) - await session.flush() # 创建成绩 score = Score( id=1, accuracy=0.9876, map_md5="example_checksum", + user_id=1, best_id=1, build_id=None, classic_total_score=1234567, ended_at=datetime.now(), has_replay=True, max_combo=1100, - mods=[APIMod(acronym="HD"), APIMod(acronym="DT")], + mods=[ + APIMod(acronym="HD", settings={}), + APIMod(acronym="DT", settings={}), + ], passed=True, playlist_item_id=None, pp=250.5, preserve=True, rank=Rank.S, room_id=None, - ruleset_id=GameMode.OSU, + gamemode=GameMode.OSU, started_at=datetime.now(), total_score=1234567, type="solo_score", position=None, beatmap_id=1, - user_id=user.id, n300=950, n100=30, n50=20, @@ -195,8 +191,6 @@ async def create_sample_beatmap_data(user: User): nslider_tail_hit=None, ) session.add(score) - await session.commit() - await session.refresh(beatmapset) print(f"成功创建示例谱面集: {beatmapset.title} (ID: {beatmapset.id})") print(f"成功创建示例谱面: {beatmap.version} (ID: {beatmap.id})") @@ -207,13 +201,14 @@ async def create_sample_beatmap_data(user: User): async def main(): print("开始创建示例数据...") await create_tables() - user = await create_sample_user() - await create_sample_beatmap_data(user) + await create_sample_user() + await create_sample_beatmap_data() print("示例数据创建完成!") - print(f"用户名: {user.name}") - print("密码: password123") - print("现在您可以使用这些凭据来测试API了。") + # print(f"用户名: {user.name}") + # print("密码: password123") + # print("现在您可以使用这些凭据来测试API了。") + await engine.dispose() if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/main.py b/main.py index ce4d222..754be83 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.dependencies.database import create_tables +from app.dependencies.database import create_tables, engine from app.router import api_router, auth_router, signalr_router from fastapi import FastAPI @@ -19,6 +19,7 @@ async def lifespan(app: FastAPI): await create_tables() # on shutdown yield + await engine.dispose() app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) diff --git a/pyproject.toml b/pyproject.toml index cf3bf81..b813496 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ ignore = [ "RUF003", # ambiguous-unicode-character-comment ] +[tool.ruff.lint.extend-per-file-ignores] +"app/database/**/*.py" = ["I002"] [tool.ruff.lint.isort] force-sort-within-sections = true