diff --git a/.gitignore b/.gitignore index 369e759..05622b7 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +test-cert/ htmlcov/ .tox/ .nox/ diff --git a/app/auth.py b/app/auth.py index 4c690f8..4762662 100644 --- a/app/auth.py +++ b/app/auth.py @@ -8,7 +8,7 @@ import string from app.config import settings from app.database import ( OAuthToken, - User as DBUser, + User, ) from app.log import logger @@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str: async def authenticate_user_legacy( db: AsyncSession, name: str, password: str -) -> DBUser | None: +) -> User | None: """ 验证用户身份 - 使用类似 from_login 的逻辑 """ @@ -82,7 +82,7 @@ async def authenticate_user_legacy( pw_md5 = hashlib.md5(password.encode()).hexdigest() # 2. 根据用户名查找用户 - statement = select(DBUser).where(DBUser.name == name) + statement = select(User).where(User.username == name) user = (await db.exec(statement)).first() if not user: return None @@ -113,7 +113,7 @@ async def authenticate_user_legacy( async def authenticate_user( db: AsyncSession, username: str, password: str -) -> DBUser | None: +) -> User | None: """验证用户身份""" return await authenticate_user_legacy(db, username, password) diff --git a/app/database/__init__.py b/app/database/__init__.py index 191a193..0ee253b 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -1,3 +1,4 @@ +from .achievement import UserAchievement, UserAchievementResp from .auth import OAuthToken from .beatmap import ( Beatmap as Beatmap, @@ -8,65 +9,64 @@ from .beatmapset import ( BeatmapsetResp as BeatmapsetResp, ) from .best_score import BestScore -from .legacy import LegacyOAuthToken, LegacyUserStatistics +from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .favourite_beatmapset import FavouriteBeatmapset +from .lazer_user import ( + User, + UserResp, +) +from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp +from .playlist_attempts import ItemAttemptsCount, ItemAttemptsResp +from .playlist_best_score import PlaylistBestScore +from .playlists import Playlist, PlaylistResp +from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType +from .room import Room, RoomResp from .score import ( + MultiplayerScores, Score, + ScoreAround, ScoreBase, ScoreResp, ScoreStatistics, ) from .score_token import ScoreToken, ScoreTokenResp +from .statistics import ( + UserStatistics, + UserStatisticsResp, +) from .team import Team, TeamMember -from .user import ( - DailyChallengeStats, - LazerUserAchievement, - LazerUserBadge, - LazerUserBanners, - LazerUserCountry, - LazerUserCounts, - LazerUserKudosu, - LazerUserMonthlyPlaycounts, - LazerUserPreviousUsername, - LazerUserProfile, - LazerUserProfileSections, - LazerUserReplaysWatched, - LazerUserStatistics, - RankHistory, - User, - UserAchievement, - UserAvatar, +from .user_account_history import ( + UserAccountHistory, + UserAccountHistoryResp, + UserAccountHistoryType, ) -BeatmapsetResp.model_rebuild() -BeatmapResp.model_rebuild() __all__ = [ "Beatmap", - "BeatmapResp", "Beatmapset", "BeatmapsetResp", "BestScore", "DailyChallengeStats", - "LazerUserAchievement", - "LazerUserBadge", - "LazerUserBanners", - "LazerUserCountry", - "LazerUserCounts", - "LazerUserKudosu", - "LazerUserMonthlyPlaycounts", - "LazerUserPreviousUsername", - "LazerUserProfile", - "LazerUserProfileSections", - "LazerUserReplaysWatched", - "LazerUserStatistics", - "LegacyOAuthToken", - "LegacyUserStatistics", + "DailyChallengeStatsResp", + "FavouriteBeatmapset", + "ItemAttemptsCount", + "ItemAttemptsResp", + "MultiplayerEvent", + "MultiplayerEventResp", + "MultiplayerScores", "OAuthToken", - "RankHistory", + "PPBestScore", + "Playlist", + "PlaylistBestScore", + "PlaylistResp", "Relationship", "RelationshipResp", "RelationshipType", + "Room", + "RoomResp", "Score", + "ScoreAround", "ScoreBase", "ScoreResp", "ScoreStatistics", @@ -75,6 +75,17 @@ __all__ = [ "Team", "TeamMember", "User", + "UserAccountHistory", + "UserAccountHistoryResp", + "UserAccountHistoryType", "UserAchievement", - "UserAvatar", + "UserAchievement", + "UserAchievementResp", + "UserResp", + "UserStatistics", + "UserStatisticsResp", ] + +for i in __all__: + if i.endswith("Resp"): + globals()[i].model_rebuild() # type: ignore[call-arg] diff --git a/app/database/achievement.py b/app/database/achievement.py new file mode 100644 index 0000000..4be587f --- /dev/null +++ b/app/database/achievement.py @@ -0,0 +1,40 @@ +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class UserAchievementBase(SQLModel, UTCBaseModel): + achievement_id: int = Field(primary_key=True) + achieved_at: datetime = Field( + default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) + ) + + +class UserAchievement(UserAchievementBase, table=True): + __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] + + id: int | None = Field(default=None, primary_key=True, index=True) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True + ) + user: "User" = Relationship(back_populates="achievement") + + +class UserAchievementResp(UserAchievementBase): + @classmethod + def from_db(cls, db_model: UserAchievement) -> "UserAchievementResp": + return cls.model_validate(db_model) diff --git a/app/database/auth.py b/app/database/auth.py index ae49676..554dced 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -1,19 +1,21 @@ from datetime import datetime from typing import TYPE_CHECKING +from app.models.model import UTCBaseModel + from sqlalchemy import Column, DateTime from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: - from .user import User + from .lazer_user import User -class OAuthToken(SQLModel, table=True): +class OAuthToken(UTCBaseModel, SQLModel, table=True): __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) access_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index b28473e..62a0120 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode from .beatmapset import Beatmapset, BeatmapsetResp from sqlalchemy import DECIMAL, Column, DateTime -from sqlalchemy.orm import joinedload from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher + from .lazer_user import User + class BeatmapOwner(SQLModel): id: int @@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True): beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus # optional - beatmapset: Beatmapset = Relationship(back_populates="beatmaps") + beatmapset: Beatmapset = Relationship( + back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} + ) @property def can_ranked(self) -> bool: @@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True): session.add(beatmap) await session.commit() beatmap = ( - await session.exec( - select(Beatmap) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) - .where(Beatmap.id == resp.id) - ) + await session.exec(select(Beatmap).where(Beatmap.id == resp.id)) ).first() assert beatmap is not None, "Beatmap should not be None after commit" return beatmap @@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True): ) -> "Beatmap": beatmap = ( await session.exec( - select(Beatmap) - .where( + select(Beatmap).where( Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 ) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) ) ).first() if not beatmap: @@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase): url: str = "" @classmethod - def from_db( + async def from_db( cls, beatmap: Beatmap, query_mode: GameMode | None = None, from_set: bool = False, + session: AsyncSession | None = None, + user: "User | None" = None, ) -> "BeatmapResp": beatmap_ = beatmap.model_dump() if query_mode is not None and beatmap.mode != query_mode: @@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=session, user=user + ) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 1e6ba27..3bad7e9 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.score import GameMode +from .lazer_user import BASE_INCLUDES, User, UserResp + from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text -from sqlmodel import Field, Relationship, SQLModel +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .beatmap import Beatmap, BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset class BeatmapCovers(SQLModel): @@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel): artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) creator: str - favourite_count: int nsfw: bool = Field(default=False) play_count: int preview_url: str @@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel): pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) ratings: list[int] = Field(default=None, sa_column=Column(JSON)) - # TODO: recent_favourites: Optional[list[User]] = None # TODO: related_users: Optional[list[User]] = None # TODO: user: Optional[User] = Field(default=None) track_id: int | None = Field(default=None) # feature artist? - # TODO: has_favourited # BeatmapsetExtended bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2))) @@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel): tags: str = Field(default="", sa_column=Column(Text)) -class Beatmapset(BeatmapsetBase, table=True): +class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True): hype_required: int = Field(default=0) availability_info: str | None = Field(default=None) download_disabled: bool = Field(default=False) + favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod async def from_resp( @@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase): genre: BeatmapTranslationText | None = None language: BeatmapTranslationText | None = None nominations: BeatmapNominations | None = None + has_favourited: bool = False + favourite_count: int = 0 + recent_favourites: list[UserResp] = Field(default_factory=list) @classmethod - def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db( + cls, + beatmapset: Beatmapset, + include: list[str] = [], + session: AsyncSession | None = None, + user: User | None = None, + ) -> "BeatmapsetResp": from .beatmap import BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset - beatmaps = [ - BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in beatmapset.beatmaps - ] + update = { + "beatmaps": [ + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps + ], + "hype": BeatmapHype( + current=beatmapset.hype_current, required=beatmapset.hype_required + ), + "availability": BeatmapAvailability( + more_information=beatmapset.availability_info, + download_disabled=beatmapset.download_disabled, + ), + "genre": BeatmapTranslationText( + name=beatmapset.beatmap_genre.name, + id=beatmapset.beatmap_genre.value, + ), + "language": BeatmapTranslationText( + name=beatmapset.beatmap_language.name, + id=beatmapset.beatmap_language.value, + ), + "nominations": BeatmapNominations( + required=beatmapset.nominations_required, + current=beatmapset.nominations_current, + ), + "status": beatmapset.beatmap_status.name.lower(), + "ranked": beatmapset.beatmap_status.value, + "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, + **beatmapset.model_dump(), + } + if session and user: + existing_favourite = ( + await session.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id + ) + ) + ).first() + update["has_favourited"] = existing_favourite is not None + + if session and "recent_favourites" in include: + recent_favourites = ( + await session.exec( + select(FavouriteBeatmapset) + .where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id, + ) + .order_by(col(FavouriteBeatmapset.date).desc()) + .limit(50) + ) + ).all() + update["recent_favourites"] = [ + await UserResp.from_db( + await favourite.awaitable_attrs.user, + session=session, + include=BASE_INCLUDES, + ) + for favourite in recent_favourites + ] + + if session: + update["favourite_count"] = ( + await session.exec( + select(func.count()) + .select_from(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + ) + ).one() return cls.model_validate( - { - "beatmaps": beatmaps, - "hype": BeatmapHype( - current=beatmapset.hype_current, required=beatmapset.hype_required - ), - "availability": BeatmapAvailability( - more_information=beatmapset.availability_info, - download_disabled=beatmapset.download_disabled, - ), - "genre": BeatmapTranslationText( - name=beatmapset.beatmap_genre.name, - id=beatmapset.beatmap_genre.value, - ), - "language": BeatmapTranslationText( - name=beatmapset.beatmap_language.name, - id=beatmapset.beatmap_language.value, - ), - "nominations": BeatmapNominations( - required=beatmapset.nominations_required, - current=beatmapset.nominations_current, - ), - "status": beatmapset.beatmap_status.name.lower(), - "ranked": beatmapset.beatmap_status.value, - "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, - **beatmapset.model_dump(), - } + update, ) diff --git a/app/database/best_score.py b/app/database/best_score.py index 313da3e..8688d5b 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -1,14 +1,14 @@ from typing import TYPE_CHECKING -from app.models.score import GameMode +from app.models.score import GameMode, Rank -from .user import User +from .lazer_user import User from sqlmodel import ( + JSON, BigInteger, Column, Field, - Float, ForeignKey, Relationship, SQLModel, @@ -20,22 +20,27 @@ if TYPE_CHECKING: class BestScore(SQLModel, table=True): - __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType] user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) score_id: int = Field( sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) ) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) - pp: float = Field( - sa_column=Column(Float, default=0), - ) - acc: float = Field( - sa_column=Column(Float, default=0), + total_score: int = Field(default=0, sa_column=Column(BigInteger)) + mods: list[str] = Field( + default_factory=list, + sa_column=Column(JSON), ) + rank: Rank user: User = Relationship() - score: "Score" = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[BestScore.score_id]", + "lazy": "joined", + } + ) beatmap: "Beatmap" = Relationship() diff --git a/app/database/daily_challenge.py b/app/database/daily_challenge.py new file mode 100644 index 0000000..abf874f --- /dev/null +++ b/app/database/daily_challenge.py @@ -0,0 +1,58 @@ +from datetime import datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class DailyChallengeStatsBase(SQLModel, UTCBaseModel): + daily_streak_best: int = Field(default=0) + daily_streak_current: int = Field(default=0) + last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) + last_weekly_streak: datetime | None = Field( + default=None, sa_column=Column(DateTime) + ) + playcount: int = Field(default=0) + top_10p_placements: int = Field(default=0) + top_50p_placements: int = Field(default=0) + weekly_streak_best: int = Field(default=0) + weekly_streak_current: int = Field(default=0) + + +class DailyChallengeStats(DailyChallengeStatsBase, table=True): + __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] + + user_id: int | None = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + unique=True, + index=True, + primary_key=True, + ), + ) + user: "User" = Relationship(back_populates="daily_challenge_stats") + + +class DailyChallengeStatsResp(DailyChallengeStatsBase): + user_id: int + + @classmethod + def from_db( + cls, + obj: DailyChallengeStats, + ) -> "DailyChallengeStatsResp": + return cls.model_validate(obj) diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py new file mode 100644 index 0000000..51bd578 --- /dev/null +++ b/app/database/favourite_beatmapset.py @@ -0,0 +1,53 @@ +import datetime + +from app.database.beatmapset import Beatmapset +from app.database.lazer_user import User + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + + +class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): + __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + beatmapset_id: int = Field( + default=None, + sa_column=Column( + ForeignKey("beatmapsets.id"), + index=True, + ), + ) + date: datetime.datetime = Field( + default=datetime.datetime.now(datetime.UTC), + sa_column=Column( + DateTime, + ), + ) + + user: User = Relationship(back_populates="favourite_beatmapsets") + beatmapset: Beatmapset = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + }, + back_populates="favourites", + ) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py new file mode 100644 index 0000000..2717c3a --- /dev/null +++ b/app/database/lazer_user.py @@ -0,0 +1,333 @@ +from datetime import UTC, datetime +from typing import TYPE_CHECKING, NotRequired, TypedDict + +from app.models.model import UTCBaseModel +from app.models.score import GameMode +from app.models.user import Country, Page, RankHistory + +from .achievement import UserAchievement, UserAchievementResp +from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp +from .statistics import UserStatistics, UserStatisticsResp +from .team import Team, TeamMember +from .user_account_history import UserAccountHistory, UserAccountHistoryResp + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + Relationship, + SQLModel, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from .favourite_beatmapset import FavouriteBeatmapset + from .relationship import RelationshipResp + + +class Kudosu(TypedDict): + available: int + total: int + + +class RankHighest(TypedDict): + rank: int + updated_at: datetime + + +class UserProfileCover(TypedDict): + url: str + custom_url: NotRequired[str] + id: NotRequired[str] + + +Badge = TypedDict( + "Badge", + { + "awarded_at": datetime, + "description": str, + "image@2x_url": str, + "image_url": str, + "url": str, + }, +) + + +class UserBase(UTCBaseModel, SQLModel): + avatar_url: str = "" + country_code: str = Field(default="CN", max_length=2, index=True) + # ? default_group: str|None + is_active: bool = True + is_bot: bool = False + is_supporter: bool = False + last_visit: datetime | None = Field( + default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) + ) + pm_friends_only: bool = False + profile_colour: str | None = None + username: str = Field(max_length=32, unique=True, index=True) + page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw="")) + previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + # TODO: replays_watched_counts + support_level: int = 0 + badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON)) + + # optional + is_restricted: bool = False + # blocks + cover: UserProfileCover = Field( + default=UserProfileCover( + url="https://assets.ppy.sh/user-profile-covers/default.jpeg" + ), + sa_column=Column(JSON), + ) + beatmap_playcounts_count: int = 0 + # kudosu + + # UserExtended + playmode: GameMode = GameMode.OSU + discord: str | None = None + has_supported: bool = False + interests: str | None = None + join_date: datetime = Field(default=datetime.now(UTC)) + location: str | None = None + max_blocks: int = 50 + max_friends: int = 500 + occupation: str | None = None + playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + # TODO: post_count + profile_hue: int | None = None + profile_order: list[str] = Field( + default_factory=lambda: [ + "me", + "recent_activity", + "top_ranks", + "medals", + "historical", + "beatmaps", + "kudosu", + ], + sa_column=Column(JSON), + ) + title: str | None = None + title_url: str | None = None + twitter: str | None = None + website: str | None = None + + # undocumented + comments_count: int = 0 + post_count: int = 0 + is_admin: bool = False + is_gmt: bool = False + is_qat: bool = False + is_bng: bool = False + + +class User(AsyncAttrs, UserBase, table=True): + __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), + ) + account_history: list[UserAccountHistory] = Relationship() + statistics: list[UserStatistics] = Relationship() + achievement: list[UserAchievement] = Relationship(back_populates="user") + team_membership: TeamMember | None = Relationship(back_populates="user") + daily_challenge_stats: DailyChallengeStats | None = Relationship( + back_populates="user" + ) + monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") + favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship( + back_populates="user" + ) + + email: str = Field(max_length=254, unique=True, index=True, exclude=True) + priv: int = Field(default=1, exclude=True) + pw_bcrypt: str = Field(max_length=60, exclude=True) + silence_end_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True)), exclude=True + ) + donor_end_at: datetime | None = Field( + default=None, sa_column=Column(DateTime(timezone=True)), exclude=True + ) + + +class UserResp(UserBase): + id: int | None = None + is_online: bool = False + groups: list = [] # TODO + country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) + favourite_beatmapset_count: int = 0 # TODO + graveyard_beatmapset_count: int = 0 # TODO + guest_beatmapset_count: int = 0 # TODO + loved_beatmapset_count: int = 0 # TODO + mapping_follower_count: int = 0 # TODO + nominated_beatmapset_count: int = 0 # TODO + pending_beatmapset_count: int = 0 # TODO + ranked_beatmapset_count: int = 0 # TODO + follow_user_mapping: list[int] = Field(default_factory=list) + follower_count: int = 0 + friends: list["RelationshipResp"] | None = None + scores_best_count: int = 0 + scores_first_count: int = 0 + scores_recent_count: int = 0 + scores_pinned_count: int = 0 + account_history: list[UserAccountHistoryResp] = [] + active_tournament_banners: list[dict] = [] # TODO + kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO + monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list) + unread_pm_count: int = 0 # TODO + rank_history: RankHistory | None = None # TODO + rank_highest: RankHighest | None = None # TODO + statistics: UserStatisticsResp | None = None + statistics_rulesets: dict[str, UserStatisticsResp] | None = None + user_achievements: list[UserAchievementResp] = Field(default_factory=list) + cover_url: str = "" # deprecated + team: Team | None = None + session_verified: bool = True + daily_challenge_user_stats: DailyChallengeStatsResp | None = None + + # TODO: monthly_playcounts, unread_pm_count, rank_history, user_preferences + + @classmethod + async def from_db( + cls, + obj: User, + session: AsyncSession, + include: list[str] = [], + ruleset: GameMode | None = None, + ) -> "UserResp": + from app.dependencies.database import get_redis + + from .best_score import BestScore + from .relationship import Relationship, RelationshipResp, RelationshipType + + u = cls.model_validate(obj.model_dump()) + u.id = obj.id + u.follower_count = ( + await session.exec( + select(func.count()) + .select_from(Relationship) + .where( + Relationship.target_id == obj.id, + Relationship.type == RelationshipType.FOLLOW, + ) + ) + ).one() + u.scores_best_count = ( + await session.exec( + select(func.count()) + .select_from(BestScore) + .where( + BestScore.user_id == obj.id, + ) + .limit(200) + ) + ).one() + redis = get_redis() + u.is_online = await redis.exists(f"metadata:online:{obj.id}") + u.cover_url = ( + obj.cover.get( + "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" + ) + if obj.cover + else "https://assets.ppy.sh/user-profile-covers/default.jpeg" + ) + + if "friends" in include: + u.friends = [ + await RelationshipResp.from_db(session, r) + for r in ( + await session.exec( + select(Relationship).where( + Relationship.user_id == obj.id, + Relationship.type == RelationshipType.FOLLOW, + ) + ) + ).all() + ] + + if "team" in include: + if await obj.awaitable_attrs.team_membership: + assert obj.team_membership + u.team = obj.team_membership.team + + if "account_history" in include: + u.account_history = [ + UserAccountHistoryResp.from_db(ah) + for ah in await obj.awaitable_attrs.account_history + ] + + if "daily_challenge_user_stats": + if await obj.awaitable_attrs.daily_challenge_stats: + assert obj.daily_challenge_stats + u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( + obj.daily_challenge_stats + ) + + if "statistics" in include: + current_stattistics = None + for i in await obj.awaitable_attrs.statistics: + if i.mode == (ruleset or obj.playmode): + current_stattistics = i + break + u.statistics = ( + UserStatisticsResp.from_db(current_stattistics) + if current_stattistics + else None + ) + + if "statistics_rulesets" in include: + u.statistics_rulesets = { + i.mode.value: UserStatisticsResp.from_db(i) + for i in await obj.awaitable_attrs.statistics + } + + if "monthly_playcounts" in include: + u.monthly_playcounts = [ + MonthlyPlaycountsResp.from_db(pc) + for pc in await obj.awaitable_attrs.monthly_playcounts + ] + + if "achievements" in include: + u.user_achievements = [ + UserAchievementResp.from_db(ua) + for ua in await obj.awaitable_attrs.achievement + ] + + return u + + +ALL_INCLUDED = [ + "friends", + "team", + "account_history", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", + "monthly_playcounts", +] + + +SEARCH_INCLUDED = [ + "team", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + "achievements", + "monthly_playcounts", +] + +BASE_INCLUDES = [ + "team", + "daily_challenge_user_stats", + "statistics", +] diff --git a/app/database/legacy.py b/app/database/legacy.py deleted file mode 100644 index ff1e957..0000000 --- a/app/database/legacy.py +++ /dev/null @@ -1,94 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING - -from sqlalchemy import JSON, Column, DateTime -from sqlalchemy.orm import Mapped -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel - -if TYPE_CHECKING: - from .user import User -# ============================================ -# 旧的兼容性表模型(保留以便向后兼容) -# ============================================ - - -class LegacyUserStatistics(SQLModel, table=True): - __tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - mode: str = Field(max_length=10) # osu, taiko, fruits, mania - - # 基本统计 - count_100: int = Field(default=0) - count_300: int = Field(default=0) - count_50: int = Field(default=0) - count_miss: int = Field(default=0) - - # 等级信息 - level_current: int = Field(default=1) - level_progress: int = Field(default=0) - - # 排名信息 - global_rank: int | None = Field(default=None) - global_rank_exp: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - - # PP 和分数 - pp: float = Field(default=0.0) - pp_exp: float = Field(default=0.0) - ranked_score: int = Field(default=0) - hit_accuracy: float = Field(default=0.0) - total_score: int = Field(default=0) - total_hits: int = Field(default=0) - maximum_combo: int = Field(default=0) - - # 游戏统计 - play_count: int = Field(default=0) - play_time: int = Field(default=0) - replays_watched_by_others: int = Field(default=0) - is_ranked: bool = Field(default=False) - - # 成绩等级计数 - grade_ss: int = Field(default=0) - grade_ssh: int = Field(default=0) - grade_s: int = Field(default=0) - grade_sh: int = Field(default=0) - grade_a: int = Field(default=0) - - # 最高排名记录 - rank_highest: int | None = Field(default=None) - rank_highest_updated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: Mapped["User"] = Relationship(back_populates="statistics") - - -class LegacyOAuthToken(SQLModel, table=True): - __tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - access_token: str = Field(max_length=255, index=True) - refresh_token: str = Field(max_length=255, index=True) - expires_at: datetime = Field(sa_column=Column(DateTime)) - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON)) - replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON)) - - # 用户关系 - user: "User" = Relationship() diff --git a/app/database/monthly_playcounts.py b/app/database/monthly_playcounts.py new file mode 100644 index 0000000..46192d1 --- /dev/null +++ b/app/database/monthly_playcounts.py @@ -0,0 +1,43 @@ +from datetime import date +from typing import TYPE_CHECKING + +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class MonthlyPlaycounts(SQLModel, table=True): + __tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True), + ) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + year: int = Field(index=True) + month: int = Field(index=True) + playcount: int = Field(default=0) + + user: "User" = Relationship(back_populates="monthly_playcounts") + + +class MonthlyPlaycountsResp(SQLModel): + start_date: date + count: int + + @classmethod + def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp": + return cls( + start_date=date(db_model.year, db_model.month, 1), + count=db_model.playcount, + ) diff --git a/app/database/multiplayer_event.py b/app/database/multiplayer_event.py new file mode 100644 index 0000000..904fbe4 --- /dev/null +++ b/app/database/multiplayer_event.py @@ -0,0 +1,56 @@ +from datetime import UTC, datetime +from typing import Any + +from app.models.model import UTCBaseModel + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + SQLModel, +) + + +class MultiplayerEventBase(SQLModel, UTCBaseModel): + playlist_item_id: int | None = None + user_id: int | None = Field( + default=None, + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), + ) + created_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + event_type: str = Field(index=True) + + +class MultiplayerEvent(MultiplayerEventBase, table=True): + __tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), + ) + room_id: int = Field(foreign_key="rooms.id", index=True) + updated_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + event_detail: dict[str, Any] | None = Field( + sa_column=Column(JSON), + default_factory=dict, + ) + + +class MultiplayerEventResp(MultiplayerEventBase): + id: int + + @classmethod + def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp": + return cls.model_validate(event) diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py new file mode 100644 index 0000000..93bc8c5 --- /dev/null +++ b/app/database/playlist_attempts.py @@ -0,0 +1,114 @@ +from .lazer_user import User, UserResp +from .playlist_best_score import PlaylistBestScore + +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + + +class ItemAttemptsCountBase(SQLModel): + room_id: int = Field(foreign_key="rooms.id", index=True) + attempts: int = Field(default=0) + completed: int = Field(default=0) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + accuracy: float = 0.0 + pp: float = 0 + total_score: int = 0 + + +class ItemAttemptsCount(ItemAttemptsCountBase, table=True): + __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] + id: int | None = Field(default=None, primary_key=True) + + user: User = Relationship() + + async def get_position(self, session: AsyncSession) -> int: + rownum = ( + func.row_number() + .over( + partition_by=col(ItemAttemptsCountBase.room_id), + order_by=col(ItemAttemptsCountBase.total_score).desc(), + ) + .label("rn") + ) + subq = select(ItemAttemptsCountBase, rownum).subquery() + stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id) + result = await session.exec(stmt) + return result.one() + + async def update(self, session: AsyncSession): + playlist_scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == self.room_id, + PlaylistBestScore.user_id == self.user_id, + ) + ) + ).all() + self.attempts = sum(score.attempts for score in playlist_scores) + self.total_score = sum(score.total_score for score in playlist_scores) + self.pp = sum(score.score.pp for score in playlist_scores) + self.completed = len(playlist_scores) + self.accuracy = ( + sum(score.score.accuracy * score.attempts for score in playlist_scores) + / self.completed + if self.completed > 0 + else 0.0 + ) + await session.commit() + await session.refresh(self) + + @classmethod + async def get_or_create( + cls, + room_id: int, + user_id: int, + session: AsyncSession, + ) -> "ItemAttemptsCount": + item_attempts = await session.exec( + select(cls).where( + cls.room_id == room_id, + cls.user_id == user_id, + ) + ) + item_attempts = item_attempts.first() + if item_attempts is None: + item_attempts = cls(room_id=room_id, user_id=user_id) + session.add(item_attempts) + await session.commit() + await session.refresh(item_attempts) + await item_attempts.update(session) + return item_attempts + + +class ItemAttemptsResp(ItemAttemptsCountBase): + user: UserResp | None = None + position: int | None = None + + @classmethod + async def from_db( + cls, + item_attempts: ItemAttemptsCount, + session: AsyncSession, + include: list[str] = [], + ) -> "ItemAttemptsResp": + resp = cls.model_validate(item_attempts) + resp.user = await UserResp.from_db( + item_attempts.user, + session=session, + include=["statistics", "team", "daily_challenge_user_stats"], + ) + if "position" in include: + resp.position = await item_attempts.get_position(session) + return resp diff --git a/app/database/playlist_best_score.py b/app/database/playlist_best_score.py new file mode 100644 index 0000000..46bbfba --- /dev/null +++ b/app/database/playlist_best_score.py @@ -0,0 +1,109 @@ +from typing import TYPE_CHECKING + +from .lazer_user import User + +from redis.asyncio import Redis +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, + col, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from .score import Score + + +class PlaylistBestScore(SQLModel, table=True): + __tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType] + + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + room_id: int = Field(foreign_key="rooms.id", index=True) + playlist_id: int = Field(foreign_key="room_playlists.id", index=True) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) + attempts: int = Field(default=0) # playlist + + user: User = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[PlaylistBestScore.score_id]", + "lazy": "joined", + } + ) + + +async def process_playlist_best_score( + room_id: int, + playlist_id: int, + user_id: int, + score_id: int, + total_score: int, + session: AsyncSession, + redis: Redis, +): + previous = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.user_id == user_id, + ) + ) + ).first() + if previous is None: + score = PlaylistBestScore( + user_id=user_id, + score_id=score_id, + room_id=room_id, + playlist_id=playlist_id, + total_score=total_score, + ) + session.add(score) + else: + previous.score_id = score_id + previous.total_score = total_score + previous.attempts += 1 + await session.commit() + await redis.decr(f"multiplayer:{room_id}:gameplay:players") + + +async def get_position( + room_id: int, + playlist_id: int, + score_id: int, + session: AsyncSession, +) -> int: + rownum = ( + func.row_number() + .over( + partition_by=( + col(PlaylistBestScore.playlist_id), + col(PlaylistBestScore.room_id), + ), + order_by=col(PlaylistBestScore.total_score).desc(), + ) + .label("row_number") + ) + subq = ( + select(PlaylistBestScore, rownum) + .where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + .subquery() + ) + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) + result = await session.exec(stmt) + s = result.one_or_none() + return s if s else 0 diff --git a/app/database/playlists.py b/app/database/playlists.py new file mode 100644 index 0000000..3f7ae40 --- /dev/null +++ b/app/database/playlists.py @@ -0,0 +1,143 @@ +from datetime import datetime +from typing import TYPE_CHECKING + +from app.models.model import UTCBaseModel +from app.models.mods import APIMod +from app.models.multiplayer_hub import PlaylistItem + +from .beatmap import Beatmap, BeatmapResp + +from sqlmodel import ( + JSON, + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, + func, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from .room import Room + + +class PlaylistBase(SQLModel, UTCBaseModel): + id: int = Field(index=True) + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) + ruleset_id: int = Field(ge=0, le=3) + expired: bool = Field(default=False) + playlist_order: int = Field(default=0) + played_at: datetime | None = Field( + sa_column=Column(DateTime(timezone=True)), + default=None, + ) + allowed_mods: list[APIMod] = Field( + default_factory=list, + sa_column=Column(JSON), + ) + required_mods: list[APIMod] = Field( + default_factory=list, + sa_column=Column(JSON), + ) + beatmap_id: int = Field( + foreign_key="beatmaps.id", + ) + freestyle: bool = Field(default=False) + + +class Playlist(PlaylistBase, table=True): + __tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType] + db_id: int = Field(default=None, primary_key=True, index=True, exclude=True) + room_id: int = Field(foreign_key="rooms.id", exclude=True) + + beatmap: Beatmap = Relationship( + sa_relationship_kwargs={ + "lazy": "joined", + } + ) + room: "Room" = Relationship() + + @classmethod + async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int: + stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where( + cls.room_id == room_id + ) + result = await session.exec(stmt) + return result.one() + + @classmethod + async def from_hub( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ) -> "Playlist": + next_id = await cls.get_next_id_for_room(room_id, session=session) + return cls( + id=next_id, + owner_id=playlist.owner_id, + ruleset_id=playlist.ruleset_id, + beatmap_id=playlist.beatmap_id, + required_mods=playlist.required_mods, + allowed_mods=playlist.allowed_mods, + expired=playlist.expired, + playlist_order=playlist.playlist_order, + played_at=playlist.played_at, + freestyle=playlist.freestyle, + room_id=room_id, + ) + + @classmethod + async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == playlist.id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + db_playlist.owner_id = playlist.owner_id + db_playlist.ruleset_id = playlist.ruleset_id + db_playlist.beatmap_id = playlist.beatmap_id + db_playlist.required_mods = playlist.required_mods + db_playlist.allowed_mods = playlist.allowed_mods + db_playlist.expired = playlist.expired + db_playlist.playlist_order = playlist.playlist_order + db_playlist.played_at = playlist.played_at + db_playlist.freestyle = playlist.freestyle + await session.commit() + + @classmethod + async def add_to_db( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ): + db_playlist = await cls.from_hub(playlist, room_id, session) + session.add(db_playlist) + await session.commit() + await session.refresh(db_playlist) + playlist.id = db_playlist.id + + @classmethod + async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == item_id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + await session.delete(db_playlist) + await session.commit() + + +class PlaylistResp(PlaylistBase): + beatmap: BeatmapResp | None = None + + @classmethod + async def from_db( + cls, playlist: Playlist, include: list[str] = [] + ) -> "PlaylistResp": + data = playlist.model_dump() + if "beatmap" in include: + data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True) + resp = cls.model_validate(data) + return resp diff --git a/app/database/pp_best_score.py b/app/database/pp_best_score.py new file mode 100644 index 0000000..ffc74d3 --- /dev/null +++ b/app/database/pp_best_score.py @@ -0,0 +1,41 @@ +from typing import TYPE_CHECKING + +from app.models.score import GameMode + +from .lazer_user import User + +from sqlmodel import ( + BigInteger, + Column, + Field, + Float, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .beatmap import Beatmap + from .score import Score + + +class PPBestScore(SQLModel, table=True): + __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) + gamemode: GameMode = Field(index=True) + pp: float = Field( + sa_column=Column(Float, default=0), + ) + acc: float = Field( + sa_column=Column(Float, default=0), + ) + + user: User = Relationship() + score: "Score" = Relationship() + beatmap: "Beatmap" = Relationship() diff --git a/app/database/relationship.py b/app/database/relationship.py index 61dc109..b941c28 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -1,8 +1,6 @@ from enum import Enum -from app.models.user import User as APIUser - -from .user import User as DBUser +from .lazer_user import User, UserResp from pydantic import BaseModel from sqlmodel import ( @@ -24,12 +22,16 @@ class RelationshipType(str, Enum): class Relationship(SQLModel, table=True): __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) user_id: int = Field( default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), - primary_key=True, + ForeignKey("lazer_users.id"), index=True, ), ) @@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True): default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), - primary_key=True, + ForeignKey("lazer_users.id"), index=True, ), ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) - target: DBUser = SQLRelationship( - sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} + target: User = SQLRelationship( + sa_relationship_kwargs={ + "foreign_keys": "[Relationship.target_id]", + "lazy": "selectin", + } ) class RelationshipResp(BaseModel): target_id: int - target: APIUser + target: UserResp mutual: bool = False type: RelationshipType @@ -58,8 +62,6 @@ class RelationshipResp(BaseModel): async def from_db( cls, session: AsyncSession, relationship: Relationship ) -> "RelationshipResp": - from app.utils import convert_db_user_to_api_user - target_relationship = ( await session.exec( select(Relationship).where( @@ -75,7 +77,16 @@ class RelationshipResp(BaseModel): ) return cls( target_id=relationship.target_id, - target=await convert_db_user_to_api_user(relationship.target), + target=await UserResp.from_db( + relationship.target, + session, + include=[ + "team", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + ], + ), mutual=mutual, type=relationship.type, ) diff --git a/app/database/room.py b/app/database/room.py index 0b79ee6..7817805 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -1,6 +1,125 @@ -from sqlmodel import Field, SQLModel +from datetime import UTC, datetime + +from app.models.multiplayer_hub import ServerMultiplayerRoom +from app.models.room import ( + MatchType, + QueueMode, + RoomCategory, + RoomDifficultyRange, + RoomPlaylistItemStats, + RoomStatus, +) + +from .lazer_user import User, UserResp +from .playlists import Playlist, PlaylistResp + +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) -class RoomIndex(SQLModel, table=True): - __tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType] - id: int = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue] +class RoomBase(SQLModel): + name: str = Field(index=True) + category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True) + duration: int | None = Field(default=None) # minutes + starts_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=datetime.now(UTC), + ) + ended_at: datetime | None = Field( + sa_column=Column( + DateTime(timezone=True), + ), + default=None, + ) + participant_count: int = Field(default=0) + max_attempts: int | None = Field(default=None) # playlists + type: MatchType + queue_mode: QueueMode + auto_skip: bool + auto_start_duration: int + status: RoomStatus + # TODO: channel_id + # recent_participants: list[User] + + +class Room(RoomBase, table=True): + __tablename__ = "rooms" # pyright: ignore[reportAssignmentType] + id: int = Field(default=None, primary_key=True, index=True) + host_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + + host: User = Relationship() + playlist: list[Playlist] = Relationship( + sa_relationship_kwargs={ + "lazy": "joined", + "cascade": "all, delete-orphan", + "overlaps": "room", + } + ) + + +class RoomResp(RoomBase): + id: int + password: str | None = None + host: UserResp | None = None + playlist: list[PlaylistResp] = [] + playlist_item_stats: RoomPlaylistItemStats | None = None + difficulty_range: RoomDifficultyRange | None = None + current_playlist_item: PlaylistResp | None = None + + @classmethod + async def from_db(cls, room: Room) -> "RoomResp": + resp = cls.model_validate(room.model_dump()) + + stats = RoomPlaylistItemStats(count_active=0, count_total=0) + difficulty_range = RoomDifficultyRange( + min=0, + max=0, + ) + rulesets = set() + for playlist in room.playlist: + stats.count_total += 1 + if not playlist.expired: + stats.count_active += 1 + rulesets.add(playlist.ruleset_id) + difficulty_range.min = min( + difficulty_range.min, playlist.beatmap.difficulty_rating + ) + difficulty_range.max = max( + difficulty_range.max, playlist.beatmap.difficulty_rating + ) + resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"])) + stats.ruleset_ids = list(rulesets) + resp.playlist_item_stats = stats + resp.difficulty_range = difficulty_range + resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None + + return resp + + @classmethod + async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp": + room = server_room.room + resp = cls( + id=room.room_id, + name=room.settings.name, + type=room.settings.match_type, + queue_mode=room.settings.queue_mode, + auto_skip=room.settings.auto_skip, + auto_start_duration=int(room.settings.auto_start_duration.total_seconds()), + status=server_room.status, + category=server_room.category, + # duration = room.settings.duration, + starts_at=server_room.start_at, + participant_count=len(room.users), + ) + return resp diff --git a/app/database/score.py b/app/database/score.py index 046f83c..37b96a3 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -1,8 +1,9 @@ import asyncio from collections.abc import Sequence -from datetime import UTC, datetime +from datetime import UTC, date, datetime +import json import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from app.calculator import ( calculate_pp, @@ -12,9 +13,8 @@ from app.calculator import ( calculate_weighted_pp, clamp, ) -from app.database.score_token import ScoreToken -from app.database.user import LazerUserStatistics, User -from app.models.beatmap import BeatmapRankStatus +from app.database.team import TeamMember +from app.models.model import RespWithCursor, UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( INT_TO_MODE, @@ -26,15 +26,24 @@ from app.models.score import ( ScoreStatistics, SoloScoreSubmissionInfo, ) -from app.models.user import User as APIUser from .beatmap import Beatmap, BeatmapResp -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmapset import BeatmapsetResp from .best_score import BestScore +from .lazer_user import User, UserResp +from .monthly_playcounts import MonthlyPlaycounts +from .pp_best_score import PPBestScore +from .relationship import ( + Relationship as DBRelationship, + RelationshipType, +) +from .score_token import ScoreToken -from redis import Redis +from redis.asyncio import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime -from sqlalchemy.orm import aliased, joinedload +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import aliased +from sqlalchemy.sql.elements import ColumnElement from sqlmodel import ( JSON, BigInteger, @@ -43,9 +52,10 @@ from sqlmodel import ( Relationship, SQLModel, col, - false, func, select, + text, + true, ) from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql._expression_select_cls import SelectOfScalar @@ -54,7 +64,7 @@ if TYPE_CHECKING: from app.fetcher import Fetcher -class ScoreBase(SQLModel): +class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # 基本字段 accuracy: float map_md5: str = Field(max_length=32, index=True) @@ -78,10 +88,11 @@ class ScoreBase(SQLModel): default=0, sa_column=Column(BigInteger), exclude=True ) type: str + beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") # optional # TODO: current_user_attributes - position: int | None = Field(default=None) # multiplayer + # position: int | None = Field(default=None) # multiplayer class Score(ScoreBase, table=True): @@ -89,12 +100,11 @@ class Score(ScoreBase, table=True): id: int | None = Field( default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True) ) - beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") user_id: int = Field( default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), + ForeignKey("lazer_users.id"), index=True, ), ) @@ -112,28 +122,13 @@ class Score(ScoreBase, table=True): gamemode: GameMode = Field(index=True) # optional - beatmap: "Beatmap" = Relationship() - user: "User" = Relationship() + beatmap: Beatmap = Relationship() + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @property def is_perfect_combo(self) -> bool: return self.max_combo == self.beatmap.max_combo - @staticmethod - def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]: - clause = select(Score).options( - joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - ) - if with_user: - return clause.options( - joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType] - ) - return clause - @staticmethod def select_clause_unique( *where_clauses: ColumnExpressionArgument[bool] | bool, @@ -147,18 +142,7 @@ class Score(ScoreBase, table=True): ) subq = select(Score, rownum).where(*where_clauses).subquery() best = aliased(Score, subq, adapt_on_names=True) - return ( - select(best) - .where(subq.c.rn == 1) - .options( - joinedload(best.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType] - ) - ) + return select(best).where(subq.c.rn == 1) class ScoreResp(ScoreBase): @@ -173,22 +157,23 @@ class ScoreResp(ScoreBase): ruleset_id: int | None = None beatmap: BeatmapResp | None = None beatmapset: BeatmapsetResp | None = None - user: APIUser | None = None + user: UserResp | None = None statistics: ScoreStatistics | None = None maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None rank_country: int | None = None + position: int | None = None + scores_around: "ScoreAround | None" = None @classmethod - async def from_db( - cls, session: AsyncSession, score: Score, user: User | None = None - ) -> "ScoreResp": - from app.utils import convert_db_user_to_api_user - + async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id - s.beatmap = BeatmapResp.from_db(score.beatmap) - s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) + await score.awaitable_attrs.beatmap + s.beatmap = await BeatmapResp.from_db(score.beatmap) + s.beatmapset = await BeatmapsetResp.from_db( + score.beatmap.beatmapset, session=session, user=score.user + ) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -220,163 +205,184 @@ class ScoreResp(ScoreBase): s.maximum_statistics = { HitResult.GREAT: score.beatmap.max_combo, } - if user: - s.user = await convert_db_user_to_api_user(user) + s.user = await UserResp.from_db( + score.user, + session, + include=["statistics", "team", "daily_challenge_user_stats"], + ruleset=score.gamemode, + ) s.rank_global = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, mode=score.gamemode, - user=user or score.user, + user=score.user, ) or None ) s.rank_country = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, score.gamemode, - user or score.user, + score.user, + type=LeaderboardType.COUNTRY, ) or None ) return s +class MultiplayerScores(RespWithCursor): + scores: list[ScoreResp] = Field(default_factory=list) + params: dict[str, Any] = Field(default_factory=dict) + + +class ScoreAround(SQLModel): + higher: MultiplayerScores | None = None + lower: MultiplayerScores | None = None + + async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( func.row_number() - .over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc()) + .over( + partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc() + ) .label("rn") ) - subq = select(BestScore, rownum).subquery() + subq = select(PPBestScore, rownum).subquery() stmt = select(subq.c.rn).where(subq.c.score_id == score_id) result = await session.exec(stmt) return result.one_or_none() +async def _score_where( + type: LeaderboardType, + beatmap: int, + mode: GameMode, + mods: list[str] | None = None, + user: User | None = None, +) -> list[ColumnElement[bool]] | None: + wheres = [ + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, + ] + + if type == LeaderboardType.FRIENDS: + if user and user.is_supporter: + subq = ( + select(DBRelationship.target_id) + .where( + DBRelationship.type == RelationshipType.FOLLOW, + DBRelationship.user_id == user.id, + ) + .subquery() + ) + wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id))) + else: + return None + elif type == LeaderboardType.COUNTRY: + if user and user.is_supporter: + wheres.append( + col(BestScore.user).has(col(User.country_code) == user.country_code) + ) + else: + return None + elif type == LeaderboardType.TEAM: + if user: + team_membership = await user.awaitable_attrs.team_membership + if team_membership: + team_id = team_membership.team_id + wheres.append( + col(BestScore.user).has( + col(User.team_membership).has(TeamMember.team_id == team_id) + ) + ) + if mods: + if user and user.is_supporter: + wheres.append( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" + ) # pyright: ignore[reportArgumentType] + ) + else: + return None + return wheres + + async def get_leaderboard( session: AsyncSession, - beatmap_md5: str, + beatmap: int, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, user: User | None = None, limit: int = 50, -) -> list[Score]: - scores = [] - if type == LeaderboardType.GLOBAL: - query = ( - select(Score) - .where( - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) - elif type == LeaderboardType.FRIENDS and user and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user and user.team_membership: - team_id = user.team_membership.team_id - query = ( - select(Score) - .join(Beatmap) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] - .where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Score.user.team_membership).is_not(None), - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) +) -> tuple[list[Score], Score | None]: + wheres = await _score_where(type, beatmap, mode, mods, user) + if wheres is None: + return [], None + query = ( + select(BestScore) + .where(*wheres) + .limit(limit) + .order_by(col(BestScore.total_score).desc()) + ) + if mods: + query = query.params(w=json.dumps(mods)) + scores = [s.score for s in await session.exec(query)] + user_score = None if user: - user_score = ( - await session.exec( - select(Score).where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - Score.user_id == user.id, - col(Score.passed).is_(True), - ) + self_query = ( + select(BestScore) + .where(BestScore.user_id == user.id) + .where( + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, ) - ).first() + .order_by(col(BestScore.total_score).desc()) + .limit(1) + ) + if mods: + self_query = self_query.where( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" + ) + ).params(w=json.dumps(mods)) + user_bs = (await session.exec(self_query)).first() + if user_bs: + user_score = user_bs.score if user_score and user_score not in scores: scores.append(user_score) - return scores + return scores, user_score async def get_score_position_by_user( session: AsyncSession, - beatmap_md5: str, + beatmap: int, user: User, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user.is_supporter: - where_clause.append(Score.mods == mods) - else: - where_clause.append(false()) - if type == LeaderboardType.FRIENDS and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user.team_membership: - team_id = user.team_membership.team_id - where_clause.append( - col(Score.user.team_membership).is_not(None), - ) - where_clause.append( - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - ) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=Score.map_md5, - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) .label("row_number") ) - subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery() - stmt = select(subq.c.row_number).where(subq.c.user == user) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.user_id == user.id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -384,57 +390,26 @@ async def get_score_position_by_user( async def get_score_position_by_id( session: AsyncSession, - beatmap_md5: str, + beatmap: int, score_id: int, mode: GameMode, user: User | None = None, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.id == score_id, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user and user.is_supporter: - where_clause.append(Score.mods == mods) - elif mods: - where_clause.append(false()) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=[col(Score.user_id), col(Score.map_md5)], - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) - .label("rownum") + .label("row_number") ) - subq = ( - select(Score.user_id, Score.id, Score.total_score, rownum) - .join(Beatmap) - .where(*where_clause) - .subquery() - ) - best_scores = aliased(subq) - overall_rank = ( - func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank") - ) - final_q = ( - select(best_scores.c.id, overall_rank) - .select_from(best_scores) - .where(best_scores.c.rownum == 1) - .subquery() - ) - - stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -445,16 +420,38 @@ async def get_user_best_score_in_beatmap( beatmap: int, user: int, mode: GameMode | None = None, -) -> Score | None: +) -> BestScore | None: return ( await session.exec( - Score.select_clause(False) + select(BestScore) .where( - Score.gamemode == mode if mode is not None else True, - Score.beatmap_id == beatmap, - Score.user_id == user, + BestScore.gamemode == mode if mode is not None else true(), + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, ) - .order_by(col(Score.total_score).desc()) + .order_by(col(BestScore.total_score).desc()) + ) + ).first() + + +# FIXME +async def get_user_best_score_with_mod_in_beatmap( + session: AsyncSession, + beatmap: int, + user: int, + mod: list[str], + mode: GameMode | None = None, +) -> BestScore | None: + return ( + await session.exec( + select(BestScore) + .where( + BestScore.gamemode == mode if mode is not None else True, + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, + # BestScore.mods == mod, + ) + .order_by(col(BestScore.total_score).desc()) ) ).first() @@ -464,13 +461,13 @@ async def get_user_best_pp_in_beatmap( beatmap: int, user: int, mode: GameMode, -) -> BestScore | None: +) -> PPBestScore | None: return ( await session.exec( - select(BestScore).where( - BestScore.beatmap_id == beatmap, - BestScore.user_id == user, - BestScore.gamemode == mode, + select(PPBestScore).where( + PPBestScore.beatmap_id == beatmap, + PPBestScore.user_id == user, + PPBestScore.gamemode == mode, ) ) ).first() @@ -480,12 +477,12 @@ async def get_user_best_pp( session: AsyncSession, user: int, limit: int = 200, -) -> Sequence[BestScore]: +) -> Sequence[PPBestScore]: return ( await session.exec( - select(BestScore) - .where(BestScore.user_id == user) - .order_by(col(BestScore.pp).desc()) + select(PPBestScore) + .where(PPBestScore.user_id == user) + .order_by(col(PPBestScore.pp).desc()) .limit(limit) ) ).all() @@ -494,27 +491,45 @@ async def get_user_best_pp( async def process_user( session: AsyncSession, user: User, score: Score, ranked: bool = False ): + assert user.id + assert score.id + mod_for_save = list({mod["acronym"] for mod in score.mods}) previous_score_best = await get_user_best_score_in_beatmap( session, score.beatmap_id, user.id, score.gamemode ) - statistics = None + previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( + session, score.beatmap_id, user.id, mod_for_save, score.gamemode + ) add_to_db = False - for i in user.lazer_statistics: + mouthly_playcount = ( + await session.exec( + select(MonthlyPlaycounts).where( + MonthlyPlaycounts.user_id == user.id, + MonthlyPlaycounts.year == date.today().year, + MonthlyPlaycounts.month == date.today().month, + ) + ) + ).first() + if mouthly_playcount is None: + mouthly_playcount = MonthlyPlaycounts( + user_id=user.id, year=date.today().year, month=date.today().month + ) + add_to_db = True + statistics = None + for i in await user.awaitable_attrs.statistics: if i.mode == score.gamemode.value: statistics = i break if statistics is None: - statistics = LazerUserStatistics( - mode=score.gamemode.value, - user_id=user.id, + raise ValueError( + f"User {user.id} does not have statistics for mode {score.gamemode.value}" ) - add_to_db = True # pc, pt, tth, tts statistics.total_score += score.total_score difference = ( score.total_score - previous_score_best.total_score - if previous_score_best and previous_score_best.id != score.id + if previous_score_best else score.total_score ) if difference > 0 and score.passed and ranked: @@ -541,11 +556,48 @@ async def process_user( statistics.grade_sh -= 1 case Rank.A: statistics.grade_a -= 1 + else: + previous_score_best = BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + session.add(previous_score_best) + statistics.ranked_score += difference statistics.level_current = calculate_score_to_level(statistics.ranked_score) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) + if score.passed and ranked: + if previous_score_best_mod is not None: + previous_score_best_mod.mods = mod_for_save + previous_score_best_mod.score_id = score.id + previous_score_best_mod.rank = score.rank + previous_score_best_mod.total_score = score.total_score + elif ( + previous_score_best is not None and previous_score_best.score_id != score.id + ): + session.add( + BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + ) statistics.play_count += 1 + mouthly_playcount.playcount += 1 statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) + statistics.count_100 += score.n100 + score.nkatu + statistics.count_300 += score.n300 + score.ngeki + statistics.count_50 += score.n50 + statistics.count_miss += score.nmiss statistics.total_hits += ( score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu ) @@ -563,11 +615,8 @@ async def process_user( acc_sum = clamp(acc_sum, 0.0, 100.0) statistics.pp = pp_sum statistics.hit_accuracy = acc_sum - - statistics.updated_at = datetime.now(UTC) - if add_to_db: - session.add(statistics) + session.add(mouthly_playcount) await session.commit() await session.refresh(user) @@ -581,7 +630,10 @@ async def process_score( fetcher: "Fetcher", session: AsyncSession, redis: Redis, + item_id: int | None = None, + room_id: int | None = None, ) -> Score: + assert user.id can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) score = Score( accuracy=info.accuracy, @@ -611,6 +663,8 @@ async def process_score( nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0), nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), + playlist_item_id=item_id, + room_id=room_id, ) if can_get_pp: beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) @@ -628,7 +682,7 @@ async def process_score( ) if previous_pp_best is None or score.pp > previous_pp_best.pp: assert score.id - best_score = BestScore( + best_score = PPBestScore( user_id=user_id, score_id=score.id, beatmap_id=beatmap_id, @@ -637,7 +691,7 @@ async def process_score( acc=score.accuracy, ) session.add(best_score) - session.delete(previous_pp_best) if previous_pp_best else None + await session.delete(previous_pp_best) if previous_pp_best else None await session.commit() await session.refresh(score) await session.refresh(score_token) diff --git a/app/database/score_token.py b/app/database/score_token.py index 6a6edb3..4467b8b 100644 --- a/app/database/score_token.py +++ b/app/database/score_token.py @@ -1,15 +1,16 @@ from datetime import datetime +from app.models.model import UTCBaseModel from app.models.score import GameMode from .beatmap import Beatmap -from .user import User +from .lazer_user import User from sqlalchemy import Column, DateTime, Index from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel -class ScoreTokenBase(SQLModel): +class ScoreTokenBase(SQLModel, UTCBaseModel): score_id: int | None = Field(sa_column=Column(BigInteger), default=None) ruleset_id: GameMode playlist_item_id: int | None = Field(default=None) # playlist @@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True): autoincrement=True, ), ) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) beatmap_id: int = Field(foreign_key="beatmaps.id") - user: "User" = Relationship() - beatmap: "Beatmap" = Relationship() + user: User = Relationship() + beatmap: Beatmap = Relationship() class ScoreTokenResp(ScoreTokenBase): diff --git a/app/database/statistics.py b/app/database/statistics.py new file mode 100644 index 0000000..cac2971 --- /dev/null +++ b/app/database/statistics.py @@ -0,0 +1,95 @@ +from typing import TYPE_CHECKING + +from app.models.score import GameMode + +from sqlmodel import ( + BigInteger, + Column, + Field, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .lazer_user import User + + +class UserStatisticsBase(SQLModel): + mode: GameMode + count_100: int = Field(default=0, sa_column=Column(BigInteger)) + count_300: int = Field(default=0, sa_column=Column(BigInteger)) + count_50: int = Field(default=0, sa_column=Column(BigInteger)) + count_miss: int = Field(default=0, sa_column=Column(BigInteger)) + + global_rank: int | None = Field(default=None) + country_rank: int | None = Field(default=None) + + pp: float = Field(default=0.0) + ranked_score: int = Field(default=0) + hit_accuracy: float = Field(default=0.00) + total_score: int = Field(default=0, sa_column=Column(BigInteger)) + total_hits: int = Field(default=0, sa_column=Column(BigInteger)) + maximum_combo: int = Field(default=0) + + play_count: int = Field(default=0) + play_time: int = Field(default=0, sa_column=Column(BigInteger)) + replays_watched_by_others: int = Field(default=0) + is_ranked: bool = Field(default=True) + + +class UserStatistics(UserStatisticsBase, table=True): + __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] + id: int | None = Field(default=None, primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + grade_ss: int = Field(default=0) + grade_ssh: int = Field(default=0) + grade_s: int = Field(default=0) + grade_sh: int = Field(default=0) + grade_a: int = Field(default=0) + + level_current: int = Field(default=1) + level_progress: int = Field(default=0) + + user: "User" = Relationship(back_populates="statistics") # type: ignore[valid-type] + + +class UserStatisticsResp(UserStatisticsBase): + grade_counts: dict[str, int] = Field( + default_factory=lambda: { + "ss": 0, + "ssh": 0, + "s": 0, + "sh": 0, + "a": 0, + } + ) + level: dict[str, int] = Field( + default_factory=lambda: { + "current": 1, + "progress": 0, + } + ) + + @classmethod + def from_db(cls, obj: UserStatistics) -> "UserStatisticsResp": + s = cls.model_validate(obj) + s.grade_counts = { + "ss": obj.grade_ss, + "ssh": obj.grade_ssh, + "s": obj.grade_s, + "sh": obj.grade_sh, + "a": obj.grade_a, + } + s.level = { + "current": obj.level_current, + "progress": obj.level_progress, + } + return s diff --git a/app/database/team.py b/app/database/team.py index 360e805..562b0c8 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -1,14 +1,16 @@ from datetime import datetime from typing import TYPE_CHECKING +from app.models.model import UTCBaseModel + from sqlalchemy import Column, DateTime from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: - from .user import User + from .lazer_user import User -class Team(SQLModel, table=True): +class Team(SQLModel, UTCBaseModel, table=True): __tablename__ = "teams" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -22,15 +24,19 @@ class Team(SQLModel, table=True): members: list["TeamMember"] = Relationship(back_populates="team") -class TeamMember(SQLModel, table=True): +class TeamMember(SQLModel, UTCBaseModel, table=True): __tablename__ = "team_members" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) team_id: int = Field(foreign_key="teams.id") joined_at: datetime = Field( default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: "User" = Relationship(back_populates="team_membership") - team: "Team" = Relationship(back_populates="members") + user: "User" = Relationship( + back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} + ) + team: "Team" = Relationship( + back_populates="members", sa_relationship_kwargs={"lazy": "joined"} + ) diff --git a/app/database/user.py b/app/database/user.py deleted file mode 100644 index a188497..0000000 --- a/app/database/user.py +++ /dev/null @@ -1,527 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Optional - -from .legacy import LegacyUserStatistics -from .team import TeamMember - -from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text -from sqlalchemy.dialects.mysql import VARCHAR -from sqlalchemy.orm import joinedload, selectinload -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select - - -class User(SQLModel, table=True): - __tablename__ = "users" # pyright: ignore[reportAssignmentType] - - # 主键 - id: int = Field( - default=None, sa_column=Column(BigInteger, primary_key=True, index=True) - ) - - # 基本信息(匹配 migrations_old 中的结构) - name: str = Field(max_length=32, unique=True, index=True) # 用户名 - safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名 - email: str = Field(max_length=254, unique=True, index=True) - priv: int = Field(default=1) # 权限 - pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码 - country: str = Field(default="CN", max_length=2) # 国家代码 - - # 状态和时间 - silence_end: int = Field(default=0) - donor_end: int = Field(default=0) - creation_time: int = Field(default=0) # Unix 时间戳 - latest_activity: int = Field(default=0) # Unix 时间戳 - - # 游戏相关 - preferred_mode: int = Field(default=0) # 偏好游戏模式 - play_style: int = Field(default=0) # 游戏风格 - - # 扩展信息 - clan_id: int = Field(default=0) - clan_priv: int = Field(default=0) - custom_badge_name: str | None = Field(default=None, max_length=16) - custom_badge_icon: str | None = Field(default=None, max_length=64) - userpage_content: str | None = Field(default=None, max_length=2048) - api_key: str | None = Field(default=None, max_length=36, unique=True) - - # 虚拟字段用于兼容性 - @property - def username(self): - return self.name - - @property - def country_code(self): - return self.country - - @property - def join_date(self): - creation_time = getattr(self, "creation_time", 0) - return ( - datetime.fromtimestamp(creation_time) - if creation_time > 0 - else datetime.utcnow() - ) - - @property - def last_visit(self): - latest_activity = getattr(self, "latest_activity", 0) - return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None - - @property - def is_supporter(self): - return self.lazer_profile.is_supporter if self.lazer_profile else False - - # 关联关系 - lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user") - lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user") - lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user") - lazer_achievements: list["LazerUserAchievement"] = Relationship( - back_populates="user" - ) - lazer_profile_sections: list["LazerUserProfileSections"] = Relationship( - back_populates="user" - ) - statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user") - team_membership: Optional["TeamMember"] = Relationship(back_populates="user") - daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship( - back_populates="user" - ) - rank_history: list["RankHistory"] = Relationship(back_populates="user") - avatar: Optional["UserAvatar"] = Relationship(back_populates="user") - active_banners: list["LazerUserBanners"] = Relationship(back_populates="user") - lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user") - lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship( - back_populates="user" - ) - lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship( - back_populates="user" - ) - lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship( - back_populates="user" - ) - - @classmethod - def all_select_option(cls): - return ( - joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType] - joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - joinedload(cls.avatar), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership), # pyright: ignore[reportArgumentType] - selectinload(cls.rank_history), # pyright: ignore[reportArgumentType] - selectinload(cls.active_banners), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType] - ) - - @classmethod - def all_select_clause(cls): - return select(cls).options(*cls.all_select_option()) - - -# ============================================ -# Lazer API 专用表模型 -# ============================================ - - -class LazerUserProfile(SQLModel, table=True): - __tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - - # 基本状态字段 - is_active: bool = Field(default=True) - is_bot: bool = Field(default=False) - is_deleted: bool = Field(default=False) - is_online: bool = Field(default=True) - is_supporter: bool = Field(default=False) - is_restricted: bool = Field(default=False) - session_verified: bool = Field(default=False) - has_supported: bool = Field(default=False) - pm_friends_only: bool = Field(default=False) - - # 基本资料字段 - default_group: str = Field(default="default", max_length=50) - last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime)) - join_date: datetime | None = Field(default=None, sa_column=Column(DateTime)) - profile_colour: str | None = Field(default=None, max_length=7) - profile_hue: int | None = Field(default=None) - - # 社交媒体和个人资料字段 - avatar_url: str | None = Field(default=None, max_length=500) - cover_url: str | None = Field(default=None, max_length=500) - discord: str | None = Field(default=None, max_length=100) - twitter: str | None = Field(default=None, max_length=100) - website: str | None = Field(default=None, max_length=500) - title: str | None = Field(default=None, max_length=100) - title_url: str | None = Field(default=None, max_length=500) - interests: str | None = Field(default=None, sa_column=Column(Text)) - location: str | None = Field(default=None, max_length=100) - - occupation: str | None = Field(default=None) # 职业字段,默认为 None - - # 游戏相关字段 - playmode: str = Field(default="osu", max_length=10) - support_level: int = Field(default=0) - max_blocks: int = Field(default=100) - max_friends: int = Field(default=500) - post_count: int = Field(default=0) - - # 页面内容 - page_html: str | None = Field(default=None, sa_column=Column(Text)) - page_raw: str | None = Field(default=None, sa_column=Column(Text)) - - profile_order: str = Field( - default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu" - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_profile") - - -class LazerUserProfileSections(SQLModel, table=True): - __tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - section_name: str = Field(sa_column=Column(VARCHAR(50))) - display_order: int | None = Field(default=None) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_profile_sections") - - -class LazerUserCountry(SQLModel, table=True): - __tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - code: str = Field(max_length=2) - name: str = Field(max_length=100) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - -class LazerUserKudosu(SQLModel, table=True): - __tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - available: int = Field(default=0) - total: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - -class LazerUserCounts(SQLModel, table=True): - __tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - - # 统计计数字段 - beatmap_playcounts_count: int = Field(default=0) - comments_count: int = Field(default=0) - favourite_beatmapset_count: int = Field(default=0) - follower_count: int = Field(default=0) - graveyard_beatmapset_count: int = Field(default=0) - guest_beatmapset_count: int = Field(default=0) - loved_beatmapset_count: int = Field(default=0) - mapping_follower_count: int = Field(default=0) - nominated_beatmapset_count: int = Field(default=0) - pending_beatmapset_count: int = Field(default=0) - ranked_beatmapset_count: int = Field(default=0) - ranked_and_approved_beatmapset_count: int = Field(default=0) - unranked_beatmapset_count: int = Field(default=0) - scores_best_count: int = Field(default=0) - scores_first_count: int = Field(default=0) - scores_pinned_count: int = Field(default=0) - scores_recent_count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_counts") - - -class LazerUserStatistics(SQLModel, table=True): - __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - mode: str = Field(default="osu", max_length=10, primary_key=True) - - # 基本命中统计 - count_100: int = Field(default=0) - count_300: int = Field(default=0) - count_50: int = Field(default=0) - count_miss: int = Field(default=0) - - # 等级信息 - level_current: int = Field(default=1) - level_progress: int = Field(default=0) - - # 排名信息 - global_rank: int | None = Field(default=None) - global_rank_exp: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - - # PP 和分数 - pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) - pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) - ranked_score: int = Field(default=0, sa_column=Column(BigInteger)) - hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2))) - total_score: int = Field(default=0, sa_column=Column(BigInteger)) - total_hits: int = Field(default=0, sa_column=Column(BigInteger)) - maximum_combo: int = Field(default=0) - - # 游戏统计 - play_count: int = Field(default=0) - play_time: int = Field(default=0) # 秒 - replays_watched_by_others: int = Field(default=0) - is_ranked: bool = Field(default=False) - - # 成绩等级计数 - grade_ss: int = Field(default=0) - grade_ssh: int = Field(default=0) - grade_s: int = Field(default=0) - grade_sh: int = Field(default=0) - grade_a: int = Field(default=0) - - # 最高排名记录 - rank_highest: int | None = Field(default=None) - rank_highest_updated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_statistics") - - -class LazerUserBanners(SQLModel, table=True): - __tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - tournament_id: int - image_url: str = Field(sa_column=Column(VARCHAR(500))) - is_active: bool | None = Field(default=None) - - # 修正user关系的back_populates值 - user: "User" = Relationship(back_populates="active_banners") - - -class LazerUserAchievement(SQLModel, table=True): - __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - achievement_id: int - achieved_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_achievements") - - -class LazerUserBadge(SQLModel, table=True): - __tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - badge_id: int - awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime)) - description: str | None = Field(default=None, sa_column=Column(Text)) - image_url: str | None = Field(default=None, max_length=500) - url: str | None = Field(default=None, max_length=500) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_badges") - - -class LazerUserMonthlyPlaycounts(SQLModel, table=True): - __tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - start_date: datetime = Field(sa_column=Column(Date)) - play_count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_monthly_playcounts") - - -class LazerUserPreviousUsername(SQLModel, table=True): - __tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - username: str = Field(max_length=32) - changed_at: datetime = Field(sa_column=Column(DateTime)) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_previous_usernames") - - -class LazerUserReplaysWatched(SQLModel, table=True): - __tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - start_date: datetime = Field(sa_column=Column(Date)) - count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_replays_watched") - - -# 类型转换用的 UserAchievement(不是 SQLAlchemy 模型) -@dataclass -class UserAchievement: - achieved_at: datetime - achievement_id: int - - -class DailyChallengeStats(SQLModel, table=True): - __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True) - ) - - daily_streak_best: int = Field(default=0) - daily_streak_current: int = Field(default=0) - last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) - last_weekly_streak: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - playcount: int = Field(default=0) - top_10p_placements: int = Field(default=0) - top_50p_placements: int = Field(default=0) - weekly_streak_best: int = Field(default=0) - weekly_streak_current: int = Field(default=0) - - user: "User" = Relationship(back_populates="daily_challenge_stats") - - -class RankHistory(SQLModel, table=True): - __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - mode: str = Field(max_length=10) - rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks - date_recorded: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="rank_history") - - -class UserAvatar(SQLModel, table=True): - __tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - filename: str = Field(max_length=255) - original_filename: str = Field(max_length=255) - file_size: int - mime_type: str = Field(max_length=100) - is_active: bool = Field(default=True) - created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - r2_original_url: str | None = Field(default=None, max_length=500) - r2_game_url: str | None = Field(default=None, max_length=500) - - user: "User" = Relationship(back_populates="avatar") diff --git a/app/database/user_account_history.py b/app/database/user_account_history.py new file mode 100644 index 0000000..217c8eb --- /dev/null +++ b/app/database/user_account_history.py @@ -0,0 +1,45 @@ +from datetime import UTC, datetime +from enum import Enum + +from app.models.model import UTCBaseModel + +from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, SQLModel + + +class UserAccountHistoryType(str, Enum): + NOTE = "note" + RESTRICTION = "restriction" + SLIENCE = "silence" + TOURNAMENT_BAN = "tournament_ban" + + +class UserAccountHistoryBase(SQLModel, UTCBaseModel): + description: str | None = None + length: int + permanent: bool = False + timestamp: datetime = Field(default=datetime.now(UTC)) + type: UserAccountHistoryType + + +class UserAccountHistory(UserAccountHistoryBase, table=True): + __tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType] + + id: int | None = Field( + sa_column=Column( + Integer, + autoincrement=True, + index=True, + primary_key=True, + ) + ) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + + +class UserAccountHistoryResp(UserAccountHistoryBase): + id: int | None = None + + @classmethod + def from_db(cls, db_model: UserAccountHistory) -> "UserAccountHistoryResp": + return cls.model_validate(db_model) diff --git a/app/dependencies/database.py b/app/dependencies/database.py index fe09139..77b15c3 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -5,15 +5,11 @@ import json from app.config import settings from pydantic import BaseModel +import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession -try: - import redis -except ImportError: - redis = None - def json_serializer(value): if isinstance(value, BaseModel | SQLModel): @@ -25,10 +21,7 @@ def json_serializer(value): engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) # Redis 连接 -if redis: - redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) -else: - redis_client = None +redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) # 数据库依赖 diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index d3c216a..806eb87 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -8,7 +8,7 @@ from app.log import logger fetcher: Fetcher | None = None -def get_fetcher() -> Fetcher: +async def get_fetcher() -> Fetcher: global fetcher if fetcher is None: fetcher = Fetcher( @@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher: settings.FETCHER_CALLBACK_URL, ) redis = get_redis() - if redis: - access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") - if access_token: - fetcher.access_token = str(access_token) - refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") - if refresh_token: - fetcher.refresh_token = str(refresh_token) - if not fetcher.access_token or not fetcher.refresh_token: - logger.opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") + if access_token: + fetcher.access_token = str(access_token) + refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") + if refresh_token: + fetcher.refresh_token = str(refresh_token) + if not fetcher.access_token or not fetcher.refresh_token: + logger.opt(colors=True).info( + f"Login to initialize fetcher: {fetcher.authorize_url}" + ) return fetcher diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 0c8f8bc..5537f4f 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -1,14 +1,13 @@ from __future__ import annotations from app.auth import get_token_by_access_token -from app.database import ( - User as DBUser, -) +from app.database import User from .database import get_db from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession security = HTTPBearer() @@ -17,7 +16,7 @@ security = HTTPBearer() async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db), -) -> DBUser: +) -> User: """获取当前认证用户""" token = credentials.credentials @@ -27,13 +26,9 @@ async def get_current_user( return user -async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None: +async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None: token_record = await get_token_by_access_token(db, token) if not token_record: return None - user = ( - await db.exec( - DBUser.all_select_clause().where(DBUser.id == token_record.user_id) - ) - ).first() + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() return user diff --git a/app/signalr/exception.py b/app/exception.py similarity index 100% rename from app/signalr/exception.py rename to app/exception.py diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 08e3508..2717a35 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -59,16 +59,15 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) async def refresh_access_token(self) -> None: async with AsyncClient() as client: @@ -87,13 +86,12 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index 08b8dfc..6e18435 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -4,7 +4,7 @@ from ._base import BaseFetcher from httpx import AsyncClient from loguru import logger -import redis +import redis.asyncio as redis class OsuDotDirectFetcher(BaseFetcher): @@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher): async def get_or_fetch_beatmap_raw( self, redis: redis.Redis, beatmap_id: int ) -> str: - if redis.exists(f"beatmap:{beatmap_id}:raw"): - return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + if await redis.exists(f"beatmap:{beatmap_id}:raw"): + return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) - redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) return raw diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 4f12e13..fae18ba 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -42,11 +42,12 @@ class Language(IntEnum): KOREAN = 6 FRENCH = 7 GERMAN = 8 - ITALIAN = 9 - SPANISH = 10 - RUSSIAN = 11 - POLISH = 12 - OTHER = 13 + SWEDISH = 9 + ITALIAN = 10 + SPANISH = 11 + RUSSIAN = 12 + POLISH = 13 + OTHER = 14 class BeatmapAttributes(BaseModel): diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 8ae3e65..684ab54 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -1,114 +1,85 @@ from __future__ import annotations from enum import IntEnum -from typing import Any, Literal +from typing import ClassVar, Literal -from app.models.signalr import UserState +from app.models.signalr import SignalRUnionMessage, UserState -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel -class _UserActivity(BaseModel): - model_config = ConfigDict(serialize_by_alias=True) - type: Literal[ - "ChoosingBeatmap", - "InSoloGame", - "WatchingReplay", - "SpectatingUser", - "SearchingForLobby", - "InLobby", - "InMultiplayerGame", - "SpectatingMultiplayerGame", - "InPlaylistGame", - "EditingBeatmap", - "ModdingBeatmap", - "TestingBeatmap", - "InDailyChallengeLobby", - "PlayingDailyChallenge", - ] = Field(alias="$dtype") - value: Any | None = Field(alias="$value") +class _UserActivity(SignalRUnionMessage): ... class ChoosingBeatmap(_UserActivity): - type: Literal["ChoosingBeatmap"] = Field(alias="$dtype") - - -class InGameValue(BaseModel): - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") - ruleset_id: int = Field(alias="RulesetID") - ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb") + union_type: ClassVar[Literal[11]] = 11 class _InGame(_UserActivity): - value: InGameValue = Field(alias="$value") + beatmap_id: int + beatmap_display_title: str + ruleset_id: int + ruleset_playing_verb: str class InSoloGame(_InGame): - type: Literal["InSoloGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[12]] = 12 class InMultiplayerGame(_InGame): - type: Literal["InMultiplayerGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[23]] = 23 class SpectatingMultiplayerGame(_InGame): - type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[24]] = 24 class InPlaylistGame(_InGame): - type: Literal["InPlaylistGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[31]] = 31 -class EditingBeatmapValue(BaseModel): - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") +class PlayingDailyChallenge(_InGame): + union_type: ClassVar[Literal[52]] = 52 class EditingBeatmap(_UserActivity): - type: Literal["EditingBeatmap"] = Field(alias="$dtype") - value: EditingBeatmapValue = Field(alias="$value") + union_type: ClassVar[Literal[41]] = 41 + beatmap_id: int + beatmap_display_title: str -class TestingBeatmap(_UserActivity): - type: Literal["TestingBeatmap"] = Field(alias="$dtype") +class TestingBeatmap(EditingBeatmap): + union_type: ClassVar[Literal[43]] = 43 -class ModdingBeatmap(_UserActivity): - type: Literal["ModdingBeatmap"] = Field(alias="$dtype") - - -class WatchingReplayValue(BaseModel): - score_id: int = Field(alias="ScoreID") - player_name: str = Field(alias="PlayerName") - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") +class ModdingBeatmap(EditingBeatmap): + union_type: ClassVar[Literal[42]] = 42 class WatchingReplay(_UserActivity): - type: Literal["WatchingReplay"] = Field(alias="$dtype") - value: int | None = Field(alias="$value") # Replay ID + union_type: ClassVar[Literal[13]] = 13 + score_id: int + player_name: str + beatmap_id: int + beatmap_display_title: str class SpectatingUser(WatchingReplay): - type: Literal["SpectatingUser"] = Field(alias="$dtype") + union_type: ClassVar[Literal[14]] = 14 class SearchingForLobby(_UserActivity): - type: Literal["SearchingForLobby"] = Field(alias="$dtype") - - -class InLobbyValue(BaseModel): - room_id: int = Field(alias="RoomID") - room_name: str = Field(alias="RoomName") + union_type: ClassVar[Literal[21]] = 21 class InLobby(_UserActivity): - type: Literal["InLobby"] = "InLobby" + union_type: ClassVar[Literal[22]] = 22 + room_id: int + room_name: str class InDailyChallengeLobby(_UserActivity): - type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype") + union_type: ClassVar[Literal[51]] = 51 UserActivity = ( @@ -128,23 +99,25 @@ UserActivity = ( ) -class MetadataClientState(UserState): - user_activity: UserActivity | None = None - status: OnlineStatus | None = None +class UserPresence(BaseModel): + activity: UserActivity | None = None - def to_dict(self) -> dict[str, Any] | None: - if self.status is None or self.status == OnlineStatus.OFFLINE: - return None - dumped = self.model_dump(by_alias=True, exclude_none=True) - return { - "Activity": dumped.get("user_activity"), - "Status": dumped.get("status"), - } + status: OnlineStatus | None = None @property def pushable(self) -> bool: return self.status is not None and self.status != OnlineStatus.OFFLINE + @property + def for_push(self) -> "UserPresence | None": + return UserPresence( + activity=self.activity, + status=self.status, + ) + + +class MetadataClientState(UserPresence, UserState): ... + class OnlineStatus(IntEnum): OFFLINE = 0 # 隐身 diff --git a/app/models/model.py b/app/models/model.py new file mode 100644 index 0000000..5ba8093 --- /dev/null +++ b/app/models/model.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from pydantic import BaseModel, field_serializer + + +class UTCBaseModel(BaseModel): + @field_serializer("*", when_used="json") + def serialize_datetime(self, v, _info): + if isinstance(v, datetime): + if v.tzinfo is None: + v = v.replace(tzinfo=UTC) + return v.astimezone(UTC).isoformat() + return v + + +Cursor = dict[str, int] + + +class RespWithCursor(BaseModel): + cursor: Cursor | None = None diff --git a/app/models/mods.py b/app/models/mods.py index abcd2cd..299a05f 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -8,7 +8,7 @@ from app.path import STATIC_DIR class APIMod(TypedDict): acronym: str - settings: NotRequired[dict[str, bool | float | str]] + settings: NotRequired[dict[str, bool | float | str | int]] # https://github.com/ppy/osu-api/wiki#mods diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py new file mode 100644 index 0000000..09d8900 --- /dev/null +++ b/app/models/multiplayer_hub.py @@ -0,0 +1,925 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from enum import IntEnum +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + ClassVar, + Literal, + TypedDict, + cast, + override, +) + +from app.database.beatmap import Beatmap +from app.dependencies.database import engine +from app.exception import InvokeException + +from .mods import APIMod +from .room import ( + DownloadState, + MatchType, + MultiplayerRoomState, + MultiplayerUserState, + QueueMode, + RoomCategory, + RoomStatus, +) +from .signalr import ( + SignalRMeta, + SignalRUnionMessage, + UserState, +) + +from pydantic import BaseModel, Field +from sqlalchemy import update +from sqlmodel import col +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.signalr.hub import MultiplayerHub + +HOST_LIMIT = 50 +PER_USER_LIMIT = 3 + + +class MultiplayerClientState(UserState): + room_id: int = 0 + + +class MultiplayerRoomSettings(BaseModel): + name: str = "Unnamed Room" + playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] + password: str = "" + match_type: MatchType = MatchType.HEAD_TO_HEAD + queue_mode: QueueMode = QueueMode.HOST_ONLY + auto_start_duration: timedelta = timedelta(seconds=0) + auto_skip: bool = False + + @property + def auto_start_enabled(self) -> bool: + return self.auto_start_duration != timedelta(seconds=0) + + +class BeatmapAvailability(BaseModel): + state: DownloadState = DownloadState.UNKNOWN + download_progress: float | None = None + + +class _MatchUserState(SignalRUnionMessage): ... + + +class TeamVersusUserState(_MatchUserState): + team_id: int + + union_type: ClassVar[Literal[0]] = 0 + + +MatchUserState = TeamVersusUserState + + +class _MatchRoomState(SignalRUnionMessage): ... + + +class MultiplayerTeam(BaseModel): + id: int + name: str + + +class TeamVersusRoomState(_MatchRoomState): + teams: list[MultiplayerTeam] = Field( + default_factory=lambda: [ + MultiplayerTeam(id=0, name="Team Red"), + MultiplayerTeam(id=1, name="Team Blue"), + ] + ) + + union_type: ClassVar[Literal[0]] = 0 + + +MatchRoomState = TeamVersusRoomState + + +class PlaylistItem(BaseModel): + id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] + owner_id: int + beatmap_id: int + beatmap_checksum: str + ruleset_id: int + required_mods: list[APIMod] = Field(default_factory=list) + allowed_mods: list[APIMod] = Field(default_factory=list) + expired: bool + playlist_order: int + played_at: datetime | None = None + star_rating: float + freestyle: bool + + def _get_api_mods(self): + from app.models.mods import API_MODS, init_mods + + if not API_MODS: + init_mods() + return API_MODS + + def _validate_mod_for_ruleset( + self, mod: APIMod, ruleset_key: int, context: str = "mod" + ) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + # Check if mod is valid for ruleset + if ( + typed_ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[typed_ruleset_key] + ): + raise InvokeException( + f"{context} {mod['acronym']} is invalid for this ruleset" + ) + + mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] + + # Check if mod is unplayable in multiplayer + if mod_settings.get("UserPlayable", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not playable by users" + ) + + if mod_settings.get("ValidForMultiplayer", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not valid for multiplayer" + ) + + def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + for i, mod1 in enumerate(mods): + mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"]) + if mod1_settings: + incompatible = set(mod1_settings.get("IncompatibleMods", [])) + for mod2 in mods[i + 1 :]: + if mod2["acronym"] in incompatible: + raise InvokeException( + f"Mods {mod1['acronym']} and " + f"{mod2['acronym']} are incompatible" + ) + + def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} + + for req_mod in self.required_mods: + req_acronym = req_mod["acronym"] + req_settings = API_MODS[typed_ruleset_key].get(req_acronym) + if req_settings: + incompatible = set(req_settings.get("IncompatibleMods", [])) + conflicting_allowed = allowed_acronyms & incompatible + if conflicting_allowed: + conflict_list = ", ".join(conflicting_allowed) + raise InvokeException( + f"Required mod {req_acronym} conflicts with " + f"allowed mods: {conflict_list}" + ) + + def validate_playlist_item_mods(self) -> None: + ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) + + # Validate required mods + for mod in self.required_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod") + + # Validate allowed mods + for mod in self.allowed_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod") + + # Check internal compatibility of required mods + self._check_mod_compatibility(self.required_mods, ruleset_key) + + # Check compatibility between required and allowed mods + self._check_required_allowed_compatibility(ruleset_key) + + def validate_user_mods( + self, + user: "MultiplayerRoomUser", + proposed_mods: list[APIMod], + ) -> tuple[bool, list[APIMod]]: + """ + Validates user mods against playlist item rules and returns valid mods. + Returns (is_valid, valid_mods). + """ + from typing import Literal, cast + + API_MODS = self._get_api_mods() + + ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id + ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) + + valid_mods = [] + all_proposed_valid = True + + # Check if mods are valid for the ruleset + for mod in proposed_mods: + if ( + ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[ruleset_key] + ): + all_proposed_valid = False + continue + valid_mods.append(mod) + + # Check mod compatibility within user mods + incompatible_mods = set() + final_valid_mods = [] + for mod in valid_mods: + if mod["acronym"] in incompatible_mods: + all_proposed_valid = False + continue + setting_mods = API_MODS[ruleset_key].get(mod["acronym"]) + if setting_mods: + incompatible_mods.update(setting_mods["IncompatibleMods"]) + final_valid_mods.append(mod) + + # If not freestyle, check against allowed mods + if not self.freestyle: + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} + filtered_valid_mods = [] + for mod in final_valid_mods: + if mod["acronym"] not in allowed_acronyms: + all_proposed_valid = False + else: + filtered_valid_mods.append(mod) + final_valid_mods = filtered_valid_mods + + # Check compatibility with required mods + required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} + all_mod_acronyms = { + mod["acronym"] for mod in final_valid_mods + } | required_mod_acronyms + + # Check for incompatibility between required and user mods + filtered_valid_mods = [] + for mod in final_valid_mods: + mod_acronym = mod["acronym"] + is_compatible = True + + for other_acronym in all_mod_acronyms: + if other_acronym == mod_acronym: + continue + setting_mods = API_MODS[ruleset_key].get(mod_acronym) + if setting_mods and other_acronym in setting_mods["IncompatibleMods"]: + is_compatible = False + all_proposed_valid = False + break + + if is_compatible: + filtered_valid_mods.append(mod) + + return all_proposed_valid, filtered_valid_mods + + def clone(self) -> "PlaylistItem": + copy = self.model_copy() + copy.required_mods = list(self.required_mods) + copy.allowed_mods = list(self.allowed_mods) + copy.expired = False + copy.played_at = None + return copy + + +class _MultiplayerCountdown(SignalRUnionMessage): + id: int = 0 + time_remaining: timedelta + is_exclusive: Annotated[ + bool, Field(default=True), SignalRMeta(member_ignore=True) + ] = True + + +class MatchStartCountdown(_MultiplayerCountdown): + union_type: ClassVar[Literal[0]] = 0 + + +class ForceGameplayStartCountdown(_MultiplayerCountdown): + union_type: ClassVar[Literal[1]] = 1 + + +class ServerShuttingDownCountdown(_MultiplayerCountdown): + union_type: ClassVar[Literal[2]] = 2 + + +MultiplayerCountdown = ( + MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown +) + + +class MultiplayerRoomUser(BaseModel): + user_id: int + state: MultiplayerUserState = MultiplayerUserState.IDLE + availability: BeatmapAvailability = BeatmapAvailability( + state=DownloadState.UNKNOWN, download_progress=None + ) + mods: list[APIMod] = Field(default_factory=list) + match_state: MatchUserState | None = None + ruleset_id: int | None = None # freestyle + beatmap_id: int | None = None # freestyle + + +class MultiplayerRoom(BaseModel): + room_id: int + state: MultiplayerRoomState + settings: MultiplayerRoomSettings + users: list[MultiplayerRoomUser] = Field(default_factory=list) + host: MultiplayerRoomUser | None = None + match_state: MatchRoomState | None = None + playlist: list[PlaylistItem] = Field(default_factory=list) + active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list) + channel_id: int + + @classmethod + def from_db(cls, room) -> "MultiplayerRoom": + """ + 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) + """ + + # 用户列表 + users = [MultiplayerRoomUser(user_id=room.host_id)] + host_user = MultiplayerRoomUser(user_id=room.host_id) + # playlist 转换 + playlist = [] + if hasattr(room, "playlist"): + for item in room.playlist: + playlist.append( + PlaylistItem( + id=item.id, + owner_id=item.owner_id, + beatmap_id=item.beatmap_id, + beatmap_checksum=item.beatmap.checksum if item.beatmap else "", + ruleset_id=item.ruleset_id, + required_mods=item.required_mods, + allowed_mods=item.allowed_mods, + expired=item.expired, + playlist_order=item.playlist_order, + played_at=item.played_at, + star_rating=item.beatmap.difficulty_rating + if item.beatmap is not None + else 0.0, + freestyle=item.freestyle, + ) + ) + + return cls( + room_id=room.id, + state=getattr(room, "state", MultiplayerRoomState.OPEN), + settings=MultiplayerRoomSettings( + name=room.name, + playlist_item_id=playlist[0].id if playlist else 0, + password=getattr(room, "password", ""), + match_type=room.type, + queue_mode=room.queue_mode, + auto_start_duration=timedelta(seconds=room.auto_start_duration), + auto_skip=room.auto_skip, + ), + users=users, + host=host_user, + match_state=None, + playlist=playlist, + active_countdowns=[], + channel_id=getattr(room, "channel_id", 0), + ) + + +class MultiplayerQueue: + def __init__(self, room: "ServerMultiplayerRoom"): + self.server_room = room + self.current_index = 0 + + @property + def hub(self) -> "MultiplayerHub": + return self.server_room.hub + + @property + def upcoming_items(self): + return sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda i: i.playlist_order, + ) + + @property + def room(self): + return self.server_room.room + + async def update_order(self): + from app.database import Playlist + + match self.room.settings.queue_mode: + case QueueMode.ALL_PLAYERS_ROUND_ROBIN: + ordered_active_items = [] + + is_first_set = True + first_set_order_by_user_id = {} + + active_items = [item for item in self.room.playlist if not item.expired] + active_items.sort(key=lambda x: x.id) + + user_item_groups = {} + for item in active_items: + if item.owner_id not in user_item_groups: + user_item_groups[item.owner_id] = [] + user_item_groups[item.owner_id].append(item) + + max_items = max( + (len(items) for items in user_item_groups.values()), default=0 + ) + + for i in range(max_items): + current_set = [] + for user_id, items in user_item_groups.items(): + if i < len(items): + current_set.append(items[i]) + + if is_first_set: + current_set.sort( + key=lambda item: (item.playlist_order, item.id) + ) + ordered_active_items.extend(current_set) + first_set_order_by_user_id = { + item.owner_id: idx + for idx, item in enumerate(ordered_active_items) + } + else: + current_set.sort( + key=lambda item: first_set_order_by_user_id.get( + item.owner_id, 0 + ) + ) + ordered_active_items.extend(current_set) + + is_first_set = False + + for idx, item in enumerate(ordered_active_items): + item.playlist_order = idx + case _: + ordered_active_items = sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda x: x.id, + ) + async with AsyncSession(engine) as session: + for idx, item in enumerate(ordered_active_items): + if item.playlist_order == idx: + continue + item.playlist_order = idx + await Playlist.update(item, self.room.room_id, session) + await self.hub.playlist_changed( + self.server_room, item, beatmap_changed=False + ) + + async def update_current_item(self): + upcoming_items = self.upcoming_items + next_item = ( + upcoming_items[0] + if upcoming_items + else max( + self.room.playlist, + key=lambda i: i.played_at or datetime.min, + ) + ) + self.current_index = self.room.playlist.index(next_item) + last_id = self.room.settings.playlist_item_id + self.room.settings.playlist_item_id = next_item.id + if last_id != next_item.id: + await self.hub.setting_changed(self.server_room, True) + + async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + is_host = self.room.host and self.room.host.user_id == user.user_id + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host: + raise InvokeException("You are not the host") + + limit = HOST_LIMIT if is_host else PER_USER_LIMIT + if ( + len([True for u in self.room.playlist if u.owner_id == user.user_id]) + >= limit + ): + raise InvokeException(f"You can only have {limit} items in the queue") + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + async with session: + beatmap = await session.get(Beatmap, item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.beatmap_checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + + item.validate_playlist_item_mods() + item.owner_id = user.user_id + item.star_rating = float( + beatmap.difficulty_rating + ) # FIXME: beatmap use decimal + await Playlist.add_to_db(item, self.room.room_id, session) + self.room.playlist.append(item) + await self.hub.playlist_added(self.server_room, item) + await self.update_order() + await self.update_current_item() + + async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + async with session: + beatmap = await session.get(Beatmap, item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.beatmap_checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + + existing_item = next( + (i for i in self.room.playlist if i.id == item.id), None + ) + if existing_item is None: + raise InvokeException( + "Attempted to change an item that doesn't exist" + ) + + if existing_item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to change an item which is not owned by the user" + ) + + if existing_item.expired: + raise InvokeException( + "Attempted to change an item which has already been played" + ) + + item.validate_playlist_item_mods() + item.owner_id = user.user_id + item.star_rating = float(beatmap.difficulty_rating) + item.playlist_order = existing_item.playlist_order + + await Playlist.update(item, self.room.room_id, session) + + # Update item in playlist + for idx, playlist_item in enumerate(self.room.playlist): + if playlist_item.id == item.id: + self.room.playlist[idx] = item + break + + await self.hub.playlist_changed( + self.server_room, + item, + beatmap_changed=item.beatmap_checksum + != existing_item.beatmap_checksum, + ) + + async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): + from app.database import Playlist + + item = next( + (i for i in self.room.playlist if i.id == playlist_item_id), + None, + ) + + if item is None: + raise InvokeException("Item does not exist in the room") + + # Check if it's the only item and current item + if item == self.current_item: + upcoming_items = [i for i in self.room.playlist if not i.expired] + if len(upcoming_items) == 1: + raise InvokeException("The only item in the room cannot be removed") + + if item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to remove an item which is not owned by the user" + ) + + if item.expired: + raise InvokeException( + "Attempted to remove an item which has already been played" + ) + + async with AsyncSession(engine) as session: + await Playlist.delete_item(item.id, self.room.room_id, session) + + self.room.playlist.remove(item) + self.current_index = self.room.playlist.index(self.upcoming_items[0]) + + await self.update_order() + await self.update_current_item() + await self.hub.playlist_removed(self.server_room, item.id) + + async def finish_current_item(self): + from app.database import Playlist + + async with AsyncSession(engine) as session: + played_at = datetime.now(UTC) + await session.execute( + update(Playlist) + .where( + col(Playlist.id) == self.current_item.id, + col(Playlist.room_id) == self.room.room_id, + ) + .values(expired=True, played_at=played_at) + ) + self.room.playlist[self.current_index].expired = True + self.room.playlist[self.current_index].played_at = played_at + await self.hub.playlist_changed(self.server_room, self.current_item, True) + await self.update_order() + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + + async def update_queue_mode(self): + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + await self.update_order() + await self.update_current_item() + + @property + def current_item(self): + return self.room.playlist[self.current_index] + + +@dataclass +class CountdownInfo: + countdown: MultiplayerCountdown + duration: timedelta + task: asyncio.Task | None = None + + def __init__(self, countdown: MultiplayerCountdown): + self.countdown = countdown + self.duration = ( + countdown.time_remaining + if countdown.time_remaining > timedelta(seconds=0) + else timedelta(seconds=0) + ) + + +class _MatchRequest(SignalRUnionMessage): ... + + +class ChangeTeamRequest(_MatchRequest): + union_type: ClassVar[Literal[0]] = 0 + team_id: int + + +class StartMatchCountdownRequest(_MatchRequest): + union_type: ClassVar[Literal[1]] = 1 + duration: timedelta + + +class StopCountdownRequest(_MatchRequest): + union_type: ClassVar[Literal[2]] = 2 + id: int + + +MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest + + +class MatchTypeHandler(ABC): + def __init__(self, room: "ServerMultiplayerRoom"): + self.room = room + self.hub = room.hub + + @abstractmethod + async def handle_join(self, user: MultiplayerRoomUser): ... + + @abstractmethod + async def handle_request( + self, user: MultiplayerRoomUser, request: MatchRequest + ): ... + + @abstractmethod + async def handle_leave(self, user: MultiplayerRoomUser): ... + + @abstractmethod + def get_details(self) -> MatchStartedEventDetail: ... + + +class HeadToHeadHandler(MatchTypeHandler): + @override + async def handle_join(self, user: MultiplayerRoomUser): + if user.match_state is not None: + user.match_state = None + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_request( + self, user: MultiplayerRoomUser, request: MatchRequest + ): ... + + @override + async def handle_leave(self, user: MultiplayerRoomUser): ... + + @override + def get_details(self) -> MatchStartedEventDetail: + detail = MatchStartedEventDetail(room_type="head_to_head", team=None) + return detail + + +class TeamVersusHandler(MatchTypeHandler): + @override + def __init__(self, room: "ServerMultiplayerRoom"): + super().__init__(room) + self.state = TeamVersusRoomState() + room.room.match_state = self.state + task = asyncio.create_task(self.hub.change_room_match_state(self.room)) + self.hub.tasks.add(task) + task.add_done_callback(self.hub.tasks.discard) + + def _get_best_available_team(self) -> int: + for team in self.state.teams: + if all( + ( + user.match_state is None + or not isinstance(user.match_state, TeamVersusUserState) + or user.match_state.team_id != team.id + ) + for user in self.room.room.users + ): + return team.id + + from collections import defaultdict + + team_counts = defaultdict(int) + for user in self.room.room.users: + if user.match_state is not None and isinstance( + user.match_state, TeamVersusUserState + ): + team_counts[user.match_state.team_id] += 1 + + if team_counts: + min_count = min(team_counts.values()) + for team_id, count in team_counts.items(): + if count == min_count: + return team_id + return self.state.teams[0].id if self.state.teams else 0 + + @override + async def handle_join(self, user: MultiplayerRoomUser): + best_team_id = self._get_best_available_team() + user.match_state = TeamVersusUserState(team_id=best_team_id) + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): + if not isinstance(request, ChangeTeamRequest): + return + + if request.team_id not in [team.id for team in self.state.teams]: + raise InvokeException("Invalid team ID") + + user.match_state = TeamVersusUserState(team_id=request.team_id) + await self.hub.change_user_match_state(self.room, user) + + @override + async def handle_leave(self, user: MultiplayerRoomUser): ... + + @override + def get_details(self) -> MatchStartedEventDetail: + teams: dict[int, Literal["blue", "red"]] = {} + for user in self.room.room.users: + if user.match_state is not None and isinstance( + user.match_state, TeamVersusUserState + ): + teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red" + detail = MatchStartedEventDetail(room_type="team_versus", team=teams) + return detail + + +MATCH_TYPE_HANDLERS = { + MatchType.HEAD_TO_HEAD: HeadToHeadHandler, + MatchType.TEAM_VERSUS: TeamVersusHandler, +} + + +@dataclass +class ServerMultiplayerRoom: + room: MultiplayerRoom + category: RoomCategory + status: RoomStatus + start_at: datetime + hub: "MultiplayerHub" + match_type_handler: MatchTypeHandler + queue: MultiplayerQueue + _next_countdown_id: int + _countdown_id_lock: asyncio.Lock + _tracked_countdown: dict[int, CountdownInfo] + + def __init__( + self, + room: MultiplayerRoom, + category: RoomCategory, + start_at: datetime, + hub: "MultiplayerHub", + ): + self.room = room + self.category = category + self.status = RoomStatus.IDLE + self.start_at = start_at + self.hub = hub + self.queue = MultiplayerQueue(self) + self._next_countdown_id = 0 + self._countdown_id_lock = asyncio.Lock() + self._tracked_countdown = {} + + async def set_handler(self): + self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type]( + self + ) + for i in self.room.users: + await self.match_type_handler.handle_join(i) + + async def get_next_countdown_id(self) -> int: + async with self._countdown_id_lock: + self._next_countdown_id += 1 + return self._next_countdown_id + + async def start_countdown( + self, + countdown: MultiplayerCountdown, + on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None, + ): + async def _countdown_task(self: "ServerMultiplayerRoom"): + await asyncio.sleep(info.duration.total_seconds()) + if on_complete is not None: + await on_complete(self) + await self.stop_countdown(countdown) + + if countdown.is_exclusive: + await self.stop_all_countdowns() + countdown.id = await self.get_next_countdown_id() + info = CountdownInfo(countdown) + self.room.active_countdowns.append(info.countdown) + self._tracked_countdown[countdown.id] = info + await self.hub.send_match_event( + self, CountdownStartedEvent(countdown=info.countdown) + ) + info.task = asyncio.create_task(_countdown_task(self)) + + async def stop_countdown(self, countdown: MultiplayerCountdown): + info = self._tracked_countdown.get(countdown.id) + if info is None: + return + del self._tracked_countdown[countdown.id] + self.room.active_countdowns.remove(countdown) + await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) + if info.task is not None and not info.task.done(): + info.task.cancel() + + async def stop_all_countdowns(self): + for countdown in list(self._tracked_countdown.values()): + await self.stop_countdown(countdown.countdown) + + self._tracked_countdown.clear() + self.room.active_countdowns.clear() + + +class _MatchServerEvent(SignalRUnionMessage): ... + + +class CountdownStartedEvent(_MatchServerEvent): + countdown: MultiplayerCountdown + + union_type: ClassVar[Literal[0]] = 0 + + +class CountdownStoppedEvent(_MatchServerEvent): + id: int + + union_type: ClassVar[Literal[1]] = 1 + + +MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent + + +class GameplayAbortReason(IntEnum): + LOAD_TOOK_TOO_LONG = 0 + HOST_ABORTED = 1 + + +class MatchStartedEventDetail(TypedDict): + room_type: Literal["playlists", "head_to_head", "team_versus"] + team: dict[int, Literal["blue", "red"]] | None diff --git a/app/models/oauth.py b/app/models/oauth.py index 22fcf63..6665965 100644 --- a/app/models/oauth.py +++ b/app/models/oauth.py @@ -1,7 +1,6 @@ # OAuth 相关模型 from __future__ import annotations -from typing import List from pydantic import BaseModel @@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel): class RegistrationErrorResponse(BaseModel): """注册错误响应模型""" + form_error: dict class UserRegistrationErrors(BaseModel): """用户注册错误模型""" - username: List[str] = [] - user_email: List[str] = [] - password: List[str] = [] + + username: list[str] = [] + user_email: list[str] = [] + password: list[str] = [] class RegistrationRequestErrors(BaseModel): """注册请求错误模型""" + message: str | None = None redirect: str | None = None user: UserRegistrationErrors | None = None diff --git a/app/models/room.py b/app/models/room.py index 2d01a26..392562a 100644 --- a/app/models/room.py +++ b/app/models/room.py @@ -1,17 +1,8 @@ from __future__ import annotations -from datetime import datetime, timedelta from enum import Enum -from app.database.beatmap import Beatmap, BeatmapResp -from app.database.user import User as DBUser -from app.fetcher import Fetcher -from app.models.mods import APIMod -from app.models.user import User -from app.utils import convert_db_user_to_api_user - -from pydantic import BaseModel, Field -from sqlmodel.ext.asyncio.session import AsyncSession +from pydantic import BaseModel class RoomCategory(str, Enum): @@ -62,53 +53,24 @@ class MultiplayerUserState(str, Enum): RESULTS = "results" SPECTATING = "spectating" + @property + def is_playing(self) -> bool: + return self in { + self.WAITING_FOR_LOAD, + self.PLAYING, + self.READY_FOR_GAMEPLAY, + self.LOADED, + } + class DownloadState(str, Enum): - UNKOWN = "unkown" + UNKNOWN = "unknown" NOT_DOWNLOADED = "not_downloaded" DOWNLOADING = "downloading" IMPORTING = "importing" LOCALLY_AVAILABLE = "locally_available" -class PlaylistItem(BaseModel): - id: int - owner_id: int - ruleset_id: int - expired: bool - playlist_order: int | None - played_at: datetime | None - allowed_mods: list[APIMod] = [] - required_mods: list[APIMod] = [] - beatmap_id: int - beatmap: BeatmapResp | None - freestyle: bool - - class Config: - exclude_none = True - - @classmethod - async def from_mpListItem( - cls, item: MultiPlayerListItem, db: AsyncSession, fetcher: Fetcher - ): - s = cls.model_validate(item.model_dump()) - s.id = item.id - s.owner_id = item.OwnerID - s.ruleset_id = item.RulesetID - s.expired = item.Expired - s.playlist_order = item.PlaylistOrder - s.played_at = item.PlayedAt - s.required_mods = item.RequierdMods - s.allowed_mods = item.AllowedMods - s.freestyle = item.Freestyle - cur_beatmap = await Beatmap.get_or_fetch( - db, fetcher=fetcher, bid=item.BeatmapID - ) - s.beatmap = BeatmapResp.from_db(cur_beatmap) - s.beatmap_id = item.BeatmapID - return s - - class RoomPlaylistItemStats(BaseModel): count_active: int count_total: int @@ -120,269 +82,7 @@ class RoomDifficultyRange(BaseModel): max: float -class ItemAttemptsCount(BaseModel): - id: int - attempts: int - passed: bool - - -class PlaylistAggregateScore(BaseModel): - playlist_item_attempts: list[ItemAttemptsCount] - - -class MultiplayerRoomSettings(BaseModel): - Name: str = "Unnamed Room" - PlaylistItemId: int - Password: str = "" - MatchType: MatchType - QueueMode: QueueMode - AutoStartDuration: timedelta - AutoSkip: bool - - @classmethod - def from_apiRoom(cls, room: Room): - s = cls.model_validate(room.model_dump()) - s.Name = room.name - s.Password = room.password if room.password is not None else "" - s.MatchType = room.type - s.QueueMode = room.queue_mode - s.AutoStartDuration = timedelta(seconds=room.auto_start_duration) - s.AutoSkip = room.auto_skip - return s - - -class BeatmapAvailability(BaseModel): - State: DownloadState - DownloadProgress: float | None - - -class MatchUserState(BaseModel): - class Config: - extra = "allow" - - -class TeamVersusState(MatchUserState): - TeamId: int - - -MatchUserStateType = TeamVersusState | MatchUserState - - -class MultiplayerRoomUser(BaseModel): - UserID: int - State: MultiplayerUserState = MultiplayerUserState.IDLE - BeatmapAvailability: BeatmapAvailability - Mods: list[APIMod] = [] - MatchUserState: MatchUserStateType | None - RulesetId: int | None - BeatmapId: int | None - User: User | None - - @classmethod - async def from_id(cls, id: int, db: AsyncSession): - actualUser = ( - await db.exec( - DBUser.all_select_clause().where( - DBUser.id == id, - ) - ) - ).first() - user = ( - await convert_db_user_to_api_user(actualUser) - if actualUser is not None - else None - ) - return MultiplayerRoomUser( - UserID=id, - MatchUserState=None, - BeatmapAvailability=BeatmapAvailability( - State=DownloadState.UNKOWN, DownloadProgress=None - ), - RulesetId=None, - BeatmapId=None, - User=user, - ) - - -class MatchRoomState(BaseModel): - class Config: - extra = "allow" - - -class MultiPlayerTeam(BaseModel): - id: int = 0 - name: str = "" - - -class TeamVersusRoomState(BaseModel): - teams: list[MultiPlayerTeam] = [] - - class Config: - pass - - @classmethod - def create_default(cls): - return cls( - teams=[ - MultiPlayerTeam(id=0, name="Team Red"), - MultiPlayerTeam(id=1, name="Team Blue"), - ] - ) - - -MatchRoomStateType = TeamVersusRoomState | MatchRoomState - - -class MultiPlayerListItem(BaseModel): - id: int - OwnerID: int - BeatmapID: int - BeatmapChecksum: str = "" - RulesetID: int - RequierdMods: list[APIMod] - AllowedMods: list[APIMod] - Expired: bool - PlaylistOrder: int | None - PlayedAt: datetime | None - StarRating: float - Freestyle: bool - - @classmethod - async def from_apiItem(cls, item: PlaylistItem, db: AsyncSession, fetcher: Fetcher): - s = cls.model_validate(item.model_dump()) - s.id = item.id - s.OwnerID = item.owner_id - if item.beatmap is None: # 从客户端接受的一定没有这字段 - cur_beatmap = await Beatmap.get_or_fetch( - db, fetcher=fetcher, bid=item.beatmap_id - ) - s.BeatmapID = cur_beatmap.id if cur_beatmap.id is not None else 0 - s.BeatmapChecksum = cur_beatmap.checksum - s.StarRating = cur_beatmap.difficulty_rating - s.RulesetID = item.ruleset_id - s.RequierdMods = item.required_mods - s.AllowedMods = item.allowed_mods - s.Expired = item.expired - s.PlaylistOrder = item.playlist_order if item.playlist_order is not None else 0 - s.PlayedAt = item.played_at - s.Freestyle = item.freestyle - return s - - -class MultiplayerCountdown(BaseModel): - id: int = 0 - time_remaining: timedelta = timedelta(seconds=0) - is_exclusive: bool = True - - class Config: - extra = "allow" - - -class MatchStartCountdown(MultiplayerCountdown): - pass - - -class ForceGameplayStartCountdown(MultiplayerCountdown): - pass - - -class ServerShuttingCountdown(MultiplayerCountdown): - pass - - -MultiplayerCountdownType = ( - MatchStartCountdown - | ForceGameplayStartCountdown - | ServerShuttingCountdown - | MultiplayerCountdown -) - - class PlaylistStatus(BaseModel): count_active: int count_total: int ruleset_ids: list[int] - - -class MultiplayerRoom(BaseModel): - RoomId: int - State: MultiplayerRoomState - Settings: MultiplayerRoomSettings = MultiplayerRoomSettings( - PlaylistItemId=0, - MatchType=MatchType.HEAD_TO_HEAD, - QueueMode=QueueMode.HOST_ONLY, - AutoStartDuration=timedelta(0), - AutoSkip=False, - ) - Users: list[MultiplayerRoomUser] - Host: MultiplayerRoomUser - MatchState: MatchRoomState | None - Playlist: list[MultiPlayerListItem] - ActivecCountDowns: list[MultiplayerCountdownType] - ChannelID: int - - @classmethod - def CanAddPlayistItem(cls, user: MultiplayerRoomUser) -> bool: - return user == cls.Host or cls.Settings.QueueMode != QueueMode.HOST_ONLY - - @classmethod - async def from_apiRoom(cls, room: Room, db: AsyncSession, fetcher: Fetcher): - s = cls.model_validate(room.model_dump()) - s.RoomId = room.room_id if room.room_id is not None else 0 - s.ChannelID = room.channel_id - s.Settings = MultiplayerRoomSettings.from_apiRoom(room) - s.Host = await MultiplayerRoomUser.from_id(room.host.id if room.host else 0, db) - s.Playlist = [ - await MultiPlayerListItem.from_apiItem(item, db, fetcher) - for item in room.playlist - ] - return s - - -class Room(BaseModel): - room_id: int - name: str - password: str | None - has_password: bool = Field(exclude=True) - host: User | None - category: RoomCategory - duration: int | None - starts_at: datetime | None - ends_at: datetime | None - max_particapants: int | None = Field(exclude=True) - particapant_count: int - recent_particapants: list[User] - type: MatchType - max_attempts: int | None - playlist: list[PlaylistItem] - playlist_item_status: list[RoomPlaylistItemStats] - difficulity_range: RoomDifficultyRange - queue_mode: QueueMode - auto_skip: bool - auto_start_duration: int - current_user_score: PlaylistAggregateScore | None - current_playlist_item: PlaylistItem | None - channel_id: int - status: RoomStatus - availability: RoomAvailability = Field(exclude=True) - - class Config: - exclude_none = True - - @classmethod - async def from_mpRoom( - cls, room: MultiplayerRoom, db: AsyncSession, fetcher: Fetcher - ): - s = cls.model_validate(room.model_dump()) - s.room_id = room.RoomId - s.name = room.Settings.Name - s.password = room.Settings.Password - s.type = room.Settings.MatchType - s.queue_mode = room.Settings.QueueMode - s.auto_skip = room.Settings.AutoSkip - s.host = room.Host.User - s.playlist = [ - await PlaylistItem.from_mpListItem(item, db, fetcher) - for item in room.Playlist - ] - return s diff --git a/app/models/score.py b/app/models/score.py index b613ae2..cef6b28 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -1,6 +1,6 @@ from __future__ import annotations -from enum import Enum, IntEnum +from enum import Enum from typing import Literal, TypedDict from .mods import API_MODS, APIMod, init_mods @@ -93,52 +93,14 @@ class HitResult(str, Enum): ) -class HitResultInt(IntEnum): - PERFECT = 0 - GREAT = 1 - GOOD = 2 - OK = 3 - MEH = 4 - MISS = 5 - - LARGE_TICK_HIT = 6 - SMALL_TICK_HIT = 7 - SLIDER_TAIL_HIT = 8 - - LARGE_BONUS = 9 - SMALL_BONUS = 10 - - LARGE_TICK_MISS = 11 - SMALL_TICK_MISS = 12 - - IGNORE_HIT = 13 - IGNORE_MISS = 14 - - NONE = 15 - COMBO_BREAK = 16 - - LEGACY_COMBO_INCREASE = 99 - - def is_hit(self) -> bool: - return self not in ( - HitResultInt.NONE, - HitResultInt.IGNORE_MISS, - HitResultInt.COMBO_BREAK, - HitResultInt.LARGE_TICK_MISS, - HitResultInt.SMALL_TICK_MISS, - HitResultInt.MISS, - ) - - class LeaderboardType(Enum): GLOBAL = "global" - FRIENDS = "friends" + FRIENDS = "friend" COUNTRY = "country" TEAM = "team" ScoreStatistics = dict[HitResult, int] -ScoreStatisticsInt = dict[HitResultInt, int] class SoloScoreSubmissionInfo(BaseModel): @@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel): class LegacyReplaySoloScoreInfo(TypedDict): online_id: int mods: list[APIMod] - statistics: ScoreStatisticsInt - maximum_statistics: ScoreStatisticsInt + statistics: ScoreStatistics + maximum_statistics: ScoreStatistics client_version: str rank: Rank user_id: int diff --git a/app/models/signalr.py b/app/models/signalr.py index 37b2741..ffbaf6b 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,54 +1,23 @@ from __future__ import annotations -import datetime -from typing import Any, get_origin +from dataclasses import dataclass +from typing import ClassVar from pydantic import ( BaseModel, - ConfigDict, Field, - TypeAdapter, - model_serializer, - model_validator, ) -def serialize_to_list(value: BaseModel) -> list[Any]: - data = [] - for field, info in value.__class__.model_fields.items(): - v = getattr(value, field) - anno = get_origin(info.annotation) - if anno and issubclass(anno, BaseModel): - data.append(serialize_to_list(v)) - elif anno and issubclass(anno, list): - data.append( - TypeAdapter( - info.annotation, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - ) - elif isinstance(v, datetime.datetime): - data.append([v, 0]) - else: - data.append(v) - return data +@dataclass +class SignalRMeta: + member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute + json_ignore: bool = False # implement of JsonIgnore (json) attribute + use_abbr: bool = True -class MessagePackArrayModel(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode="before") - @classmethod - def unpack(cls, v: Any) -> Any: - if isinstance(v, list): - fields = list(cls.model_fields.keys()) - if len(v) != len(fields): - raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}") - return dict(zip(fields, v)) - return v - - @model_serializer - def serialize(self) -> list[Any]: - return serialize_to_list(self) +class SignalRUnionMessage(BaseModel): + union_type: ClassVar[int] class Transport(BaseModel): diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index 994e083..9f35932 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -2,17 +2,17 @@ from __future__ import annotations import datetime from enum import IntEnum -from typing import Any +from typing import Annotated, Any from app.models.beatmap import BeatmapRankStatus +from app.models.mods import APIMod from .score import ( - ScoreStatisticsInt, + ScoreStatistics, ) -from .signalr import MessagePackArrayModel, UserState +from .signalr import SignalRMeta, UserState -from msgpack_lazer_api import APIMod -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field, field_validator class SpectatedUserState(IntEnum): @@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum): Quit = 5 -class SpectatorState(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class SpectatorState(BaseModel): beatmap_id: int | None = None ruleset_id: int | None = None # 0,1,2,3 mods: list[APIMod] = Field(default_factory=list) state: SpectatedUserState - maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict) + maximum_statistics: ScoreStatistics = Field(default_factory=dict) def __eq__(self, other: object) -> bool: if not isinstance(other, SpectatorState): @@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel): ) -class ScoreProcessorStatistics(MessagePackArrayModel): - base_score: int - maximum_base_score: int +class ScoreProcessorStatistics(BaseModel): + base_score: float + maximum_base_score: float accuracy_judgement_count: int combo_portion: float - bouns_portion: float + bonus_portion: float -class FrameHeader(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class FrameHeader(BaseModel): total_score: int - acc: float + accuracy: float combo: int max_combo: int - statistics: ScoreStatisticsInt = Field(default_factory=dict) + statistics: ScoreStatistics = Field(default_factory=dict) score_processor_statistics: ScoreProcessorStatistics received_time: datetime.datetime mods: list[APIMod] = Field(default_factory=list) @@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel): # SMOKE = 16 -class LegacyReplayFrame(MessagePackArrayModel): +class LegacyReplayFrame(BaseModel): time: float # from ReplayFrame,the parent of LegacyReplayFrame - x: float | None = None - y: float | None = None + mouse_x: float | None = None + mouse_y: float | None = None button_state: int + header: Annotated[ + FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True) + ] -class FrameDataBundle(MessagePackArrayModel): + +class FrameDataBundle(BaseModel): header: FrameHeader frames: list[LegacyReplayFrame] @@ -106,18 +106,16 @@ class APIUser(BaseModel): class ScoreInfo(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - mods: list[APIMod] user: APIUser ruleset: int - maximum_statistics: ScoreStatisticsInt + maximum_statistics: ScoreStatistics id: int | None = None total_score: int | None = None - acc: float | None = None + accuracy: float | None = None max_combo: int | None = None combo: int | None = None - statistics: ScoreStatisticsInt = Field(default_factory=dict) + statistics: ScoreStatistics = Field(default_factory=dict) class StoreScore(BaseModel): diff --git a/app/models/user.py b/app/models/user.py index 4eb0911..3052eef 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,15 +2,11 @@ from __future__ import annotations from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING -from .score import GameMode +from .model import UTCBaseModel from pydantic import BaseModel -if TYPE_CHECKING: - from app.database import LazerUserAchievement, Team - class PlayStyle(str, Enum): MOUSE = "mouse" @@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel): count: int -class UserAchievement(BaseModel): - achieved_at: datetime - achievement_id: int - - # 添加数据库模型转换方法 - def to_db_model(self, user_id: int) -> "LazerUserAchievement": - from app.database import ( - LazerUserAchievement, - ) - - return LazerUserAchievement( - user_id=user_id, - achievement_id=self.achievement_id, - achieved_at=self.achieved_at, - ) - - -class RankHighest(BaseModel): +class RankHighest(UTCBaseModel): rank: int updated_at: datetime @@ -104,115 +83,6 @@ class RankHistory(BaseModel): data: list[int] -class DailyChallengeStats(BaseModel): - daily_streak_best: int = 0 - daily_streak_current: int = 0 - last_update: datetime | None = None - last_weekly_streak: datetime | None = None - playcount: int = 0 - top_10p_placements: int = 0 - top_50p_placements: int = 0 - user_id: int - weekly_streak_best: int = 0 - weekly_streak_current: int = 0 - - class Page(BaseModel): html: str = "" raw: str = "" - - -class User(BaseModel): - # 基本信息 - id: int - username: str - avatar_url: str - country_code: str - default_group: str = "default" - is_active: bool = True - is_bot: bool = False - is_deleted: bool = False - is_online: bool = True - is_supporter: bool = False - is_restricted: bool = False - last_visit: datetime | None = None - pm_friends_only: bool = False - profile_colour: str | None = None - - # 个人资料 - cover_url: str | None = None - discord: str | None = None - has_supported: bool = False - interests: str | None = None - join_date: datetime - location: str | None = None - max_blocks: int = 100 - max_friends: int = 500 - occupation: str | None = None - playmode: GameMode = GameMode.OSU - playstyle: list[PlayStyle] = [] - post_count: int = 0 - profile_hue: int | None = None - profile_order: list[str] = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - title: str | None = None - title_url: str | None = None - twitter: str | None = None - website: str | None = None - session_verified: bool = False - support_level: int = 0 - - # 关联对象 - country: Country - cover: Cover - kudosu: Kudosu - statistics: Statistics - statistics_rulesets: dict[str, Statistics] - - # 计数信息 - beatmap_playcounts_count: int = 0 - comments_count: int = 0 - favourite_beatmapset_count: int = 0 - follower_count: int = 0 - graveyard_beatmapset_count: int = 0 - guest_beatmapset_count: int = 0 - loved_beatmapset_count: int = 0 - mapping_follower_count: int = 0 - nominated_beatmapset_count: int = 0 - pending_beatmapset_count: int = 0 - ranked_beatmapset_count: int = 0 - ranked_and_approved_beatmapset_count: int = 0 - unranked_beatmapset_count: int = 0 - scores_best_count: int = 0 - scores_first_count: int = 0 - scores_pinned_count: int = 0 - scores_recent_count: int = 0 - - # 历史数据 - account_history: list[dict] = [] - active_tournament_banner: dict | None = None - active_tournament_banners: list[dict] = [] - badges: list[dict] = [] - current_season_stats: dict | None = None - daily_challenge_user_stats: DailyChallengeStats | None = None - groups: list[dict] = [] - monthly_playcounts: list[MonthlyPlaycount] = [] - page: Page = Page() - previous_usernames: list[str] = [] - rank_highest: RankHighest | None = None - rank_history: RankHistory | None = None - rankHistory: RankHistory | None = None # 兼容性别名 - replays_watched_counts: list[dict] = [] - team: "Team | None" = None - user_achievements: list[UserAchievement] = [] - - -class APIUser(BaseModel): - id: int diff --git a/app/router/__init__.py b/app/router/__init__.py index 1e87343..22f6c70 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401 beatmapset, me, relationship, + room, score, user, ) @@ -14,4 +15,9 @@ from .api_router import router as api_router from .auth import router as auth_router from .fetcher import fetcher_router as fetcher_router -__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"] +__all__ = [ + "api_router", + "auth_router", + "fetcher_router", + "signalr_router", +] diff --git a/app/router/auth.py b/app/router/auth.py index 0f41b32..7a2a14d 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import timedelta +from datetime import UTC, datetime, timedelta import re from app.auth import ( @@ -12,17 +12,21 @@ from app.auth import ( store_token, ) from app.config import settings -from app.database import User as DBUser +from app.database import DailyChallengeStats, User +from app.database.statistics import UserStatistics from app.dependencies import get_db +from app.log import logger from app.models.oauth import ( OAuthErrorResponse, RegistrationRequestErrors, TokenResponse, UserRegistrationErrors, ) +from app.models.score import GameMode from fastapi import APIRouter, Depends, Form from fastapi.responses import JSONResponse +from sqlalchemy import text from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -110,12 +114,12 @@ async def register_user( email_errors = validate_email(user_email) password_errors = validate_password(user_password) - result = await db.exec(select(DBUser).where(DBUser.name == user_username)) + result = await db.exec(select(User).where(User.username == user_username)) existing_user = result.first() if existing_user: username_errors.append("Username is already taken") - result = await db.exec(select(DBUser).where(DBUser.email == user_email)) + result = await db.exec(select(User).where(User.email == user_email)) existing_email = result.first() if existing_email: email_errors.append("Email is already taken") @@ -135,119 +139,41 @@ async def register_user( try: # 创建新用户 - from datetime import datetime - import time + # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) + result = await db.execute( # pyright: ignore[reportDeprecated] + text( + "SELECT AUTO_INCREMENT FROM information_schema.TABLES " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'" + ) + ) + next_id = result.one()[0] + if next_id <= 2: + await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3")) + await db.commit() - new_user = DBUser( - name=user_username, - safe_name=user_username.lower(), # 安全用户名(小写) + new_user = User( + username=user_username, email=user_email, pw_bcrypt=get_password_hash(user_password), priv=1, # 普通用户权限 - country="CN", # 默认国家 - creation_time=int(time.time()), - latest_activity=int(time.time()), - preferred_mode=0, # 默认模式 - play_style=0, # 默认游戏风格 + country_code="CN", # 默认国家 + join_date=datetime.now(UTC), + last_visit=datetime.now(UTC), ) - db.add(new_user) await db.commit() await db.refresh(new_user) - - # 保存用户ID,因为会话可能会关闭 - user_id = new_user.id - - if user_id <= 2: - await db.rollback() - try: - from sqlalchemy import text - - # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) - await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3")) - await db.commit() - - # 重新创建用户 - new_user = DBUser( - name=user_username, - safe_name=user_username.lower(), - email=user_email, - pw_bcrypt=get_password_hash(user_password), - priv=1, - country="CN", - creation_time=int(time.time()), - latest_activity=int(time.time()), - preferred_mode=0, - play_style=0, - ) - - db.add(new_user) - await db.commit() - await db.refresh(new_user) - user_id = new_user.id - - # 最终检查ID是否有效 - if user_id <= 2: - await db.rollback() - errors = RegistrationRequestErrors( - message=( - "Failed to create account with valid ID. " - "Please contact support." - ) - ) - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) - - except Exception as fix_error: - await db.rollback() - print(f"Failed to fix AUTO_INCREMENT: {fix_error}") - errors = RegistrationRequestErrors( - message="Failed to create account with valid ID. Please try again." - ) - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) - - # 创建默认的 lazer_profile - from app.database.user import LazerUserProfile - - lazer_profile = LazerUserProfile( - user_id=user_id, - is_active=True, - is_bot=False, - is_deleted=False, - is_online=True, - is_supporter=False, - is_restricted=False, - session_verified=False, - has_supported=False, - pm_friends_only=False, - default_group="default", - join_date=datetime.utcnow(), - playmode="osu", - support_level=0, - max_blocks=50, - max_friends=250, - post_count=0, - ) - - db.add(lazer_profile) + assert new_user.id is not None, "New user ID should not be None" + for i in GameMode: + statistics = UserStatistics(mode=i, user_id=new_user.id) + db.add(statistics) + daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id) + db.add(daily_challenge_user_stats) await db.commit() - - # 返回成功响应 - return JSONResponse( - status_code=201, - content={"message": "Account created successfully", "user_id": user_id}, - ) - - except Exception as e: + except Exception: await db.rollback() # 打印详细错误信息用于调试 - print(f"Registration error: {e}") - import traceback - - traceback.print_exc() + logger.exception(f"Registration error for user {user_username}") # 返回通用错误 errors = RegistrationRequestErrors( @@ -323,6 +249,7 @@ async def oauth_token( refresh_token_str = generate_refresh_token() # 存储令牌 + assert user.id await store_token( db, user.id, diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 71d554f..7dfd0f9 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,12 +5,7 @@ import hashlib import json from app.calculator import calculate_beatmap_attribute -from app.database import ( - Beatmap, - BeatmapResp, - User as DBUser, -) -from app.database.beatmapset import Beatmapset +from app.database import Beatmap, BeatmapResp, User from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -27,9 +22,8 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis import rosu_pp_py as rosu -from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -39,7 +33,7 @@ async def lookup_beatmap( id: int | None = Query(default=None, alias="id"), md5: str | None = Query(default=None, alias="checksum"), filename: str | None = Query(default=None, alias="filename"), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -56,19 +50,19 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) async def get_beatmap( bid: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -80,43 +74,28 @@ class BatchGetResp(BaseModel): @router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp) @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) async def batch_get_beatmaps( - b_ids: list[int] = Query(alias="id", default_factory=list), - current_user: DBUser = Depends(get_current_user), + b_ids: list[int] = Query(alias="ids[]", default_factory=list), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if not b_ids: # select 50 beatmaps by last_updated beatmaps = ( await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .order_by(col(Beatmap.last_updated).desc()) - .limit(50) + select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) ) ).all() else: beatmaps = ( - await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .where(col(Beatmap.id).in_(b_ids)) - .limit(50) - ) + await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp( + beatmaps=[ + await BeatmapResp.from_db(bm, session=db, user=current_user) + for bm in beatmaps + ] + ) @router.post( @@ -126,7 +105,7 @@ async def batch_get_beatmaps( ) async def get_beatmap_attributes( beatmap: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), mods: list[str] = Query(default_factory=list), ruleset: GameMode | None = Query(default=None), ruleset_id: int | None = Query(default=None), @@ -153,8 +132,8 @@ async def get_beatmap_attributes( f"beatmap:{beatmap}:{ruleset}:" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" ) - if redis.exists(key): - return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] + if await redis.exists(key): + return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) @@ -164,7 +143,7 @@ async def get_beatmap_attributes( ) except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] raise HTTPException(status_code=400, detail=str(e)) - redis.set(key, attr.model_dump_json()) + await redis.set(key, attr.model_dump_json()) return attr except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index db2dd77..f77c2ed 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -1,10 +1,8 @@ from __future__ import annotations -from app.database import ( - Beatmapset, - BeatmapsetResp, - User as DBUser, -) +from typing import Literal + +from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -12,27 +10,47 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, HTTPException +from fastapi import Depends, Form, HTTPException, Query +from fastapi.responses import RedirectResponse from httpx import HTTPStatusError -from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp) +async def lookup_beatmapset( + beatmap_id: int = Query(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), + fetcher: Fetcher = Depends(get_fetcher), +): + beatmapset_id = ( + await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id)) + ).first() + if not beatmapset_id: + try: + resp = await fetcher.get_beatmap(beatmap_id) + await Beatmap.from_resp(db, resp) + await db.refresh(current_user) + except HTTPStatusError: + raise HTTPException(status_code=404, detail="Beatmapset not found") + beatmapset = ( + await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id)) + ).first() + if not beatmapset: + raise HTTPException(status_code=404, detail="Beatmapset not found") + resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user) + return resp + + @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = ( - await db.exec( - select(Beatmapset) - .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] - .where(Beatmapset.id == sid) - ) - ).first() + beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() if not beatmapset: try: resp = await fetcher.get_beatmapset(sid) @@ -40,5 +58,55 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db( + beatmapset, session=db, include=["recent_favourites"], user=current_user + ) return resp + + +@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"]) +async def download_beatmapset( + beatmapset: int, + no_video: bool = Query(True, alias="noVideo"), + current_user: User = Depends(get_current_user), +): + if current_user.country_code == "CN": + return RedirectResponse( + f"https://txy1.sayobot.cn/beatmaps/download/" + f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto" + ) + else: + return RedirectResponse( + f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}" + ) + + +@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"]) +async def favourite_beatmapset( + beatmapset: int, + action: Literal["favourite", "unfavourite"] = Form(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + existing_favourite = ( + await db.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.user_id == current_user.id, + FavouriteBeatmapset.beatmapset_id == beatmapset, + ) + ) + ).first() + + if action == "favourite" and existing_favourite: + raise HTTPException(status_code=400, detail="Already favourited") + elif action == "unfavourite" and not existing_favourite: + raise HTTPException(status_code=400, detail="Not favourited") + + if action == "favourite": + favourite = FavouriteBeatmapset( + user_id=current_user.id, beatmapset_id=beatmapset + ) + db.add(favourite) + else: + await db.delete(existing_favourite) + await db.commit() diff --git a/app/router/me.py b/app/router/me.py index 93dcbdc..b6d7d26 100644 --- a/app/router/me.py +++ b/app/router/me.py @@ -1,28 +1,27 @@ from __future__ import annotations -from typing import Literal - -from app.database import ( - User as DBUser, -) +from app.database import User, UserResp +from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user -from app.models.user import ( - User as ApiUser, -) -from app.utils import convert_db_user_to_api_user +from app.dependencies.database import get_db +from app.models.score import GameMode from .api_router import router from fastapi import Depends +from sqlmodel.ext.asyncio.session import AsyncSession -@router.get("/me/{ruleset}", response_model=ApiUser) -@router.get("/me/", response_model=ApiUser) +@router.get("/me/{ruleset}", response_model=UserResp) +@router.get("/me/", response_model=UserResp) async def get_user_info_default( - ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", - current_user: DBUser = Depends(get_current_user), + ruleset: GameMode | None = None, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), ): - """获取当前用户信息(默认使用osu模式)""" - # 默认使用osu模式 - api_user = await convert_db_user_to_api_user(current_user, ruleset) - return api_user + return await UserResp.from_db( + current_user, + session, + ALL_INCLUDED, + ruleset, + ) diff --git a/app/router/relationship.py b/app/router/relationship.py index 9ed5b0f..0832d09 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user from .api_router import router from fastapi import Depends, HTTPException, Query, Request -from sqlalchemy.orm import joinedload +from pydantic import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -26,17 +26,19 @@ async def get_relationship( else RelationshipType.BLOCK ) relationships = await db.exec( - select(Relationship) - .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType] - .where( + select(Relationship).where( Relationship.user_id == current_user.id, Relationship.type == relationship_type, ) ) - return [await RelationshipResp.from_db(db, rel) for rel in relationships] + return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()] -@router.post("/friends", tags=["relationship"], response_model=RelationshipResp) +class AddFriendResp(BaseModel): + user_relation: RelationshipResp + + +@router.post("/friends", tags=["relationship"], response_model=AddFriendResp) @router.post("/blocks", tags=["relationship"]) async def add_relationship( request: Request, @@ -87,14 +89,10 @@ async def add_relationship( if origin_type == RelationshipType.FOLLOW: relationship = ( await db.exec( - select(Relationship) - .where( + select(Relationship).where( Relationship.user_id == current_user_id, Relationship.target_id == target, ) - .options( - joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType] - ) ) ).first() assert relationship, "Relationship should exist after commit" diff --git a/app/router/room.py b/app/router/room.py index 6e6f4a1..2677b75 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -1,106 +1,296 @@ from __future__ import annotations -from app.database.room import RoomIndex +from datetime import UTC, datetime +from typing import Literal + +from app.database.beatmap import Beatmap, BeatmapResp +from app.database.beatmapset import BeatmapsetResp +from app.database.lazer_user import User, UserResp +from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp +from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp +from app.database.playlists import Playlist, PlaylistResp +from app.database.room import Room, RoomBase, RoomResp +from app.database.score import Score from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher +from app.dependencies.user import get_current_user from app.fetcher import Fetcher -from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room +from app.models.multiplayer_hub import ( + MultiplayerRoom, + MultiplayerRoomUser, + ServerMultiplayerRoom, +) +from app.models.room import RoomStatus +from app.signalr.hub import MultiplayerHubs + +from .api_router import router -from api_router import router from fastapi import Depends, HTTPException, Query -from sqlmodel import select +from pydantic import BaseModel, Field +from redis.asyncio import Redis +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession -@router.get("/rooms", tags=["rooms"], response_model=list[Room]) +@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp]) async def get_all_rooms( - mode: str | None = Query(None), # TODO: 对房间根据状态进行筛选 - status: str | None = Query(None), - category: str | None = Query( - None - ), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗) + mode: Literal["open", "ended", "participated", "owned", None] = Query( + default="open" + ), # TODO: 对房间根据状态进行筛选 + category: str = Query(default="realtime"), # TODO + status: RoomStatus | None = Query(None), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), + redis: Redis = Depends(get_redis), + current_user: User = Depends(get_current_user), ): - all_roomID = (await db.exec(select(RoomIndex))).all() - redis = get_redis() - if redis is not None: - resp: list[Room] = [] - for id in all_roomID: - dumped_room = redis.get(str(id)) - validated_room = MultiplayerRoom.model_validate_json(str(dumped_room)) - flag: bool = False - if status is not None: - if ( - validated_room.State == MultiplayerRoomState.OPEN - and status == "idle" - ): - flag = True - elif validated_room != MultiplayerRoomState.CLOSED: - flag = True - if flag: - resp.append( - await Room.from_mpRoom( - MultiplayerRoom.model_validate_json(str(dumped_room)), - db, - fetcher, - ) - ) - return resp - else: - raise HTTPException(status_code=500, detail="Redis Error") + rooms = MultiplayerHubs.rooms.values() + resp_list: list[RoomResp] = [] + for room in rooms: + # if category == "realtime" and room.category != "normal": + # continue + # elif category != room.category and category != "": + # continue + resp_list.append(await RoomResp.from_hub(room)) + return resp_list -@router.get("/rooms/{room}", tags=["room"], response_model=Room) +class APICreatedRoom(RoomResp): + error: str = "" + + +class APIUploadedRoom(RoomBase): + def to_room(self) -> Room: + """ + 将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。 + """ + room_dict = self.model_dump() + room_dict.pop("playlist", None) + # host_id 已在字段中 + return Room(**room_dict) + + id: int | None + host_id: int | None = None + playlist: list[Playlist] + + +@router.post("/rooms", tags=["room"], response_model=APICreatedRoom) +async def create_room( + room: APIUploadedRoom, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + # db_room = Room.from_resp(room) + await db.refresh(current_user) + user_id = current_user.id + db_room = room.to_room() + db_room.host_id = current_user.id if current_user.id else 1 + db.add(db_room) + await db.commit() + await db.refresh(db_room) + + playlist: list[Playlist] = [] + # 处理 APIUploadedRoom 里的 playlist 字段 + for item in room.playlist: + # 确保 room_id 正确赋值 + item.id = await Playlist.get_next_id_for_room(db_room.id, db) + item.room_id = db_room.id + item.owner_id = user_id if user_id else 1 + db.add(item) + await db.commit() + await db.refresh(item) + playlist.append(item) + await db.refresh(db_room) + db_room.playlist = playlist + server_room = ServerMultiplayerRoom( + room=MultiplayerRoom.from_db(db_room), + category=db_room.category, + start_at=datetime.now(UTC), + hub=MultiplayerHubs, + ) + MultiplayerHubs.rooms[db_room.id] = server_room + created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room)) + created_room.error = "" + return created_room + + +@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp) async def get_room( room: int, db: AsyncSession = Depends(get_db), - fetcher: Fetcher = Depends(get_fetcher), ): - redis = get_redis() - if redis: - dumped_room = str(redis.get(str(room))) - if dumped_room is not None: - resp = await Room.from_mpRoom( - MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher - ) - return resp - else: - raise HTTPException(status_code=404, detail="Room Not Found") - else: - raise HTTPException(status_code=500, detail="Redis error") - - -class APICreatedRoom(Room): - error: str | None - - -@router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom) -async def create_room( - room: Room, - db: AsyncSession = Depends(get_db), - fetcher: Fetcher = Depends(get_fetcher), -): - redis = get_redis() - if redis: - room_index = RoomIndex() - db.add(room_index) - await db.commit() - await db.refresh(room_index) - server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher) - redis.set(str(room_index.id), server_room.model_dump_json()) - room.room_id = room_index.id - return APICreatedRoom(**room.model_dump(), error=None) - else: - raise HTTPException(status_code=500, detail="redis error") + server_room = MultiplayerHubs.rooms[room] + return await RoomResp.from_hub(server_room) @router.delete("/rooms/{room}", tags=["room"]) -async def remove_room(room: int, db: AsyncSession = Depends(get_db)): - redis = get_redis() - if redis: - redis.delete(str(room)) - room_index = await db.get(RoomIndex, room) - if room_index: - await db.delete(room_index) +async def delete_room(room: int, db: AsyncSession = Depends(get_db)): + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is None: + raise HTTPException(404, "Room not found") + else: + await db.delete(db_room) + return None + + +@router.put("/rooms/{room}/users/{user}", tags=["room"]) +async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)): + server_room = MultiplayerHubs.rooms[room] + server_room.room.users.append(MultiplayerRoomUser(user_id=user)) + db_room = (await db.exec(select(Room).where(Room.id == room))).first() + if db_room is not None: + db_room.participant_count += 1 await db.commit() + resp = await RoomResp.from_hub(server_room) + await db.refresh(db_room) + for item in db_room.playlist: + resp.playlist.append(await PlaylistResp.from_db(item, ["beatmap"])) + return resp + else: + raise HTTPException(404, "room not found0") + + +class APILeaderboard(BaseModel): + leaderboard: list[ItemAttemptsResp] = Field(default_factory=list) + user_score: ItemAttemptsResp | None = None + + +@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard) +async def get_room_leaderboard( + room: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + server_room = MultiplayerHubs.rooms[room] + if not server_room: + raise HTTPException(404, "Room not found") + + aggs = await db.exec( + select(ItemAttemptsCount) + .where(ItemAttemptsCount.room_id == room) + .order_by(col(ItemAttemptsCount.total_score).desc()) + ) + aggs_resp = [] + user_agg = None + for i, agg in enumerate(aggs): + resp = await ItemAttemptsResp.from_db(agg, db) + resp.position = i + 1 + aggs_resp.append(resp) + if agg.user_id == current_user.id: + user_agg = resp + return APILeaderboard( + leaderboard=aggs_resp, + user_score=user_agg, + ) + + +class RoomEvents(BaseModel): + beatmaps: list[BeatmapResp] = Field(default_factory=list) + beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict) + current_playlist_item_id: int = 0 + events: list[MultiplayerEventResp] = Field(default_factory=list) + first_event_id: int = 0 + last_event_id: int = 0 + playlist_items: list[PlaylistResp] = Field(default_factory=list) + room: RoomResp + user: list[UserResp] = Field(default_factory=list) + + +@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"]) +async def get_room_events( + room_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), + limit: int = Query(100, ge=1, le=1000), + after: int | None = Query(None, ge=0), + before: int | None = Query(None, ge=0), +): + events = ( + await db.exec( + select(MultiplayerEvent) + .where( + MultiplayerEvent.room_id == room_id, + col(MultiplayerEvent.id) > after if after is not None else True, + col(MultiplayerEvent.id) < before if before is not None else True, + ) + .order_by(col(MultiplayerEvent.id).desc()) + .limit(limit) + ) + ).all() + + user_ids = set() + playlist_items = {} + beatmap_ids = set() + + event_resps = [] + first_event_id = 0 + last_event_id = 0 + + current_playlist_item_id = 0 + for event in events: + event_resps.append(MultiplayerEventResp.from_db(event)) + + if event.user_id: + user_ids.add(event.user_id) + + if event.playlist_item_id is not None and ( + playitem := ( + await db.exec( + select(Playlist).where( + Playlist.id == event.playlist_item_id, + Playlist.room_id == room_id, + ) + ) + ).first() + ): + current_playlist_item_id = playitem.id + playlist_items[event.playlist_item_id] = playitem + beatmap_ids.add(playitem.beatmap_id) + scores = await db.exec( + select(Score).where( + Score.playlist_item_id == event.playlist_item_id, + Score.room_id == room_id, + ) + ) + for score in scores: + user_ids.add(score.user_id) + beatmap_ids.add(score.beatmap_id) + + assert event.id is not None + first_event_id = min(first_event_id, event.id) + last_event_id = max(last_event_id, event.id) + + if room := MultiplayerHubs.rooms.get(room_id): + current_playlist_item_id = room.queue.current_item.id + room_resp = await RoomResp.from_hub(room) + else: + room = (await db.exec(select(Room).where(Room.id == room_id))).first() + if room is None: + raise HTTPException(404, "Room not found") + room_resp = await RoomResp.from_db(room) + + users = await db.exec(select(User).where(col(User.id).in_(user_ids))) + user_resps = [await UserResp.from_db(user, db) for user in users] + beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids))) + beatmap_resps = [ + await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps + ] + beatmapset_resps = {} + for beatmap_resp in beatmap_resps: + beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset + + playlist_items_resps = [ + await PlaylistResp.from_db(item) for item in playlist_items.values() + ] + + return RoomEvents( + beatmaps=beatmap_resps, + beatmapsets=beatmapset_resps, + current_playlist_item_id=current_playlist_item_id, + events=event_resps, + first_event_id=first_event_id, + last_event_id=last_event_id, + playlist_items=playlist_items_resps, + room=room_resp, + user=user_resps, + ) diff --git a/app/router/score.py b/app/router/score.py index cc38dcc..5db171d 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,18 +1,41 @@ from __future__ import annotations +from datetime import UTC, datetime +import time + +from app.calculator import clamp from app.database import ( - User as DBUser, + Beatmap, + Playlist, + Room, + Score, + ScoreResp, + ScoreToken, + ScoreTokenResp, + User, +) +from app.database.playlist_attempts import ItemAttemptsCount +from app.database.playlist_best_score import ( + PlaylistBestScore, + get_position, + process_playlist_best_score, +) +from app.database.score import ( + MultiplayerScores, + ScoreAround, + get_leaderboard, + process_score, + process_user, ) -from app.database.beatmap import Beatmap -from app.database.score import Score, ScoreResp, process_score, process_user -from app.database.score_token import ScoreToken, ScoreTokenResp from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus from app.models.score import ( INT_TO_MODE, GameMode, + LeaderboardType, Rank, SoloScoreSubmissionInfo, ) @@ -21,11 +44,75 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis from sqlalchemy.orm import joinedload -from sqlmodel import col, select, true +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +READ_SCORE_TIMEOUT = 10 + + +async def submit_score( + info: SoloScoreSubmissionInfo, + beatmap: int, + token: int, + current_user: User, + db: AsyncSession, + redis: Redis, + fetcher: Fetcher, + item_id: int | None = None, + room_id: int | None = None, +): + if not info.passed: + info.rank = Rank.F + score_token = ( + await db.exec( + select(ScoreToken) + .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] + .where(ScoreToken.id == token) + ) + ).first() + if not score_token or score_token.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Score token not found") + if score_token.score_id: + score = ( + await db.exec( + select(Score).where( + Score.id == score_token.score_id, + Score.user_id == current_user.id, + ) + ) + ).first() + if not score: + raise HTTPException(status_code=404, detail="Score not found") + else: + beatmap_status = ( + await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)) + ).first() + if beatmap_status is None: + raise HTTPException(status_code=404, detail="Beatmap not found") + ranked = beatmap_status in { + BeatmapRankStatus.RANKED, + BeatmapRankStatus.APPROVED, + } + score = await process_score( + current_user, + beatmap, + ranked, + score_token, + info, + fetcher, + db, + redis, + ) + await db.refresh(current_user) + score_id = score.id + score_token.score_id = score_id + await process_user(db, current_user, score, ranked) + score = (await db.exec(select(Score).where(Score.id == score_id))).first() + assert score is not None + return await ScoreResp.from_db(db, score) + class BeatmapScores(BaseModel): scores: list[ScoreResp] @@ -37,44 +124,26 @@ class BeatmapScores(BaseModel): ) async def get_beatmap_scores( beatmap: int, + mode: GameMode, legacy_only: bool = Query(None), # TODO:加入对这个参数的查询 - mode: GameMode | None = Query(None), - # mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询 - type: str = Query(None), - current_user: DBUser = Depends(get_current_user), + mods: list[str] = Query(default_factory=set, alias="mods[]"), + type: LeaderboardType = Query(LeaderboardType.GLOBAL), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), + limit: int = Query(50, ge=1, le=200), ): if legacy_only: raise HTTPException( status_code=404, detail="this server only contains lazer scores" ) - all_scores = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).all() - - user_score = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - Score.user_id == current_user.id, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).first() + all_scores, user_score = await get_leaderboard( + db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods + ) return BeatmapScores( - scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores], - userScore=await ScoreResp.from_db(db, user_score, user_score.user) - if user_score - else None, + scores=[await ScoreResp.from_db(db, score) for score in all_scores], + userScore=await ScoreResp.from_db(db, user_score) if user_score else None, ) @@ -94,7 +163,7 @@ async def get_user_beatmap_score( legacy_only: bool = Query(None), mode: str = Query(None), mods: str = Query(None), # TODO:添加mods筛选 - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -103,7 +172,7 @@ async def get_user_beatmap_score( ) user_score = ( await db.exec( - Score.select_clause(True) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, @@ -118,9 +187,10 @@ async def get_user_beatmap_score( status_code=404, detail=f"Cannot find user {user}'s score on this beatmap" ) else: + resp = await ScoreResp.from_db(db, user_score) return BeatmapUserScore( - position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score, user_score.user), + position=resp.rank_global or 0, + score=resp, ) @@ -134,7 +204,7 @@ async def get_user_all_beatmap_scores( user: int, legacy_only: bool = Query(None), ruleset: str = Query(None), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -143,7 +213,7 @@ async def get_user_all_beatmap_scores( ) all_user_scores = ( await db.exec( - Score.select_clause() + select(Score) .where( Score.gamemode == ruleset if ruleset is not None else True, Score.beatmap_id == beatmap, @@ -153,9 +223,7 @@ async def get_user_all_beatmap_scores( ) ).all() - return [ - await ScoreResp.from_db(db, score, current_user) for score in all_user_scores - ] + return [await ScoreResp.from_db(db, score) for score in all_user_scores] @router.post( @@ -166,9 +234,10 @@ async def create_solo_score( version_hash: str = Form(""), beatmap_hash: str = Form(), ruleset_id: int = Form(..., ge=0, le=3), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): + assert current_user.id async with db: score_token = ScoreToken( user_id=current_user.id, @@ -190,64 +259,275 @@ async def submit_solo_score( beatmap: int, token: int, info: SoloScoreSubmissionInfo, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), ): - if not info.passed: - info.rank = Rank.F - async with db: - score_token = ( - await db.exec( - select(ScoreToken) - .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] - .where(ScoreToken.id == token, ScoreToken.user_id == current_user.id) + return await submit_score(info, beatmap, token, current_user, db, redis, fetcher) + + +@router.post( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp +) +async def create_playlist_score( + room_id: int, + playlist_id: int, + beatmap_id: int = Form(), + beatmap_hash: str = Form(), + ruleset_id: int = Form(..., ge=0, le=3), + version_hash: str = Form(""), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + room = await session.get(Room, room_id) + if not room: + raise HTTPException(status_code=404, detail="Room not found") + if room.ended_at and room.ended_at < datetime.now(UTC): + raise HTTPException(status_code=400, detail="Room has ended") + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id ) - ).first() - if not score_token or score_token.user_id != current_user.id: - raise HTTPException(status_code=404, detail="Score token not found") - if score_token.score_id: - score = ( - await db.exec( - select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] - .where( - Score.id == score_token.score_id, - Score.user_id == current_user.id, + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist not found") + + # validate + if not item.freestyle: + if item.ruleset_id != ruleset_id: + raise HTTPException( + status_code=400, detail="Ruleset mismatch in playlist item" + ) + if item.beatmap_id != beatmap_id: + raise HTTPException( + status_code=400, detail="Beatmap ID mismatch in playlist item" + ) + agg = await session.exec( + select(ItemAttemptsCount).where( + ItemAttemptsCount.room_id == room_id, + ItemAttemptsCount.user_id == current_user.id, + ) + ) + agg = agg.first() + if agg and room.max_attempts and agg.attempts >= room.max_attempts: + raise HTTPException( + status_code=422, + detail="You have reached the maximum attempts for this room", + ) + if item.expired: + raise HTTPException(status_code=400, detail="Playlist item has expired") + if item.played_at: + raise HTTPException( + status_code=400, detail="Playlist item has already been played" + ) + # 这里应该不用验证mod了吧。。。 + + score_token = ScoreToken( + user_id=current_user.id, + beatmap_id=beatmap_id, + ruleset_id=INT_TO_MODE[ruleset_id], + playlist_item_id=playlist_id, + ) + session.add(score_token) + await session.commit() + await session.refresh(score_token) + return ScoreTokenResp.from_db(score_token) + + +@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}") +async def submit_playlist_score( + room_id: int, + playlist_id: int, + token: int, + info: SoloScoreSubmissionInfo, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), + fetcher: Fetcher = Depends(get_fetcher), +): + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id + ) + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist item not found") + + user_id = current_user.id + score_resp = await submit_score( + info, + item.beatmap_id, + token, + current_user, + session, + redis, + fetcher, + item.id, + room_id, + ) + await process_playlist_best_score( + room_id, + playlist_id, + user_id, + score_resp.id, + score_resp.total_score, + session, + redis, + ) + await ItemAttemptsCount.get_or_create(room_id, user_id, session) + return score_resp + + +class IndexedScoreResp(MultiplayerScores): + total: int + user_score: ScoreResp | None = None + + +@router.get( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp +) +async def index_playlist_scores( + room_id: int, + playlist_id: int, + limit: int = 50, + cursor: int = Query(2000000, alias="cursor[total_score]"), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + limit = clamp(limit, 1, 50) + + scores = ( + await session.exec( + select(PlaylistBestScore) + .where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + PlaylistBestScore.total_score < cursor, + ) + .order_by(col(PlaylistBestScore.total_score).desc()) + .limit(limit + 1) + ) + ).all() + has_more = len(scores) > limit + if has_more: + scores = scores[:-1] + + user_score = None + score_resp = [await ScoreResp.from_db(session, score.score) for score in scores] + for score in score_resp: + score.position = await get_position(room_id, playlist_id, score.id, session) + if score.user_id == current_user.id: + user_score = score + resp = IndexedScoreResp( + scores=score_resp, + user_score=user_score, + total=len(scores), + params={ + "limit": limit, + }, + ) + if has_more: + resp.cursor = { + "total_score": scores[-1].total_score, + } + return resp + + +@router.get( + "/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}", + response_model=ScoreResp, +) +async def show_playlist_score( + room_id: int, + playlist_id: int, + score_id: int, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + start_time = time.time() + score_record = None + completed = False + while time.time() - start_time < READ_SCORE_TIMEOUT: + if score_record is None: + score_record = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.score_id == score_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, ) ) ).first() - if not score: - raise HTTPException(status_code=404, detail="Score not found") - else: - beatmap_status = ( - await db.exec( - select(Beatmap.beatmap_status).where(Beatmap.id == beatmap) + if completed_players := await redis.get( + f"multiplayer:{room_id}:gameplay:players" + ): + completed = completed_players == "0" + if score_record and completed: + break + if not score_record: + raise HTTPException(status_code=404, detail="Score not found") + resp = await ScoreResp.from_db(session, score_record.score) + resp.position = await get_position(room_id, playlist_id, score_id, session) + if completed: + scores = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, ) - ).first() - if beatmap_status is None: - raise HTTPException(status_code=404, detail="Beatmap not found") - ranked = beatmap_status in { - BeatmapRankStatus.RANKED, - BeatmapRankStatus.APPROVED, - } - score = await process_score( - current_user, - beatmap, - ranked, - score_token, - info, - fetcher, - db, - redis, ) - await db.refresh(current_user) - score_id = score.id - score_token.score_id = score_id - await process_user(db, current_user, score, ranked) - score = ( - await db.exec(Score.select_clause().where(Score.id == score_id)) - ).first() - assert score is not None - return await ScoreResp.from_db(db, score, current_user) + ).all() + higher_scores = [] + lower_scores = [] + for score in scores: + if score.total_score > resp.total_score: + higher_scores.append(await ScoreResp.from_db(session, score.score)) + elif score.total_score < resp.total_score: + lower_scores.append(await ScoreResp.from_db(session, score.score)) + resp.scores_around = ScoreAround( + higher=MultiplayerScores(scores=higher_scores), + lower=MultiplayerScores(scores=lower_scores), + ) + + return resp + + +@router.get( + "rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}", + response_model=ScoreResp, +) +async def get_user_playlist_score( + room_id: int, + playlist_id: int, + user_id: int, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + score_record = None + start_time = time.time() + while time.time() - start_time < READ_SCORE_TIMEOUT: + score_record = ( + await session.exec( + select(PlaylistBestScore).where( + PlaylistBestScore.user_id == user_id, + PlaylistBestScore.playlist_id == playlist_id, + PlaylistBestScore.room_id == room_id, + ) + ) + ).first() + if score_record: + break + if not score_record: + raise HTTPException(status_code=404, detail="Score not found") + + resp = await ScoreResp.from_db(session, score_record.score) + resp.position = await get_position( + room_id, playlist_id, score_record.score_id, session + ) + return resp diff --git a/app/router/user.py b/app/router/user.py index 6e169c3..649f1d4 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -1,12 +1,9 @@ from __future__ import annotations -from typing import Literal - -from app.database import User as DBUser +from app.database import User, UserResp +from app.database.lazer_user import SEARCH_INCLUDED from app.dependencies.database import get_db -from app.models.score import INT_TO_MODE -from app.models.user import User as ApiUser -from app.utils import convert_db_user_to_api_user +from app.models.score import GameMode from .api_router import router @@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import col -# ---------- Shared Utility ---------- -async def get_user_by_lookup( - db: AsyncSession, lookup: str, key: str = "id" -) -> DBUser | None: - """根据查找方式获取用户""" - if key == "id": - try: - user_id = int(lookup) - result = await db.exec(select(DBUser).where(DBUser.id == user_id)) - return result.first() - except ValueError: - return None - elif key == "username": - result = await db.exec(select(DBUser).where(DBUser.name == lookup)) - return result.first() - else: - return None - - -# ---------- Batch Users ---------- class BatchUserResponse(BaseModel): - users: list[ApiUser] + users: list[UserResp] @router.get("/users", response_model=BatchUserResponse) @@ -51,75 +28,44 @@ async def get_users( ): if user_ids: searched_users = ( - await session.exec( - DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids)) - ) + await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids))) ).all() else: - searched_users = ( - await session.exec(DBUser.all_select_clause().limit(50)) - ).all() + searched_users = (await session.exec(select(User).limit(50))).all() return BatchUserResponse( users=[ - await convert_db_user_to_api_user( - searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value + await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDED, ) for searched_user in searched_users ] ) -# # ---------- Individual User ---------- -# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser) -# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser) -# async def get_user_with_mode( -# user_lookup: str, -# mode: Literal["osu", "taiko", "fruits", "mania"], -# key: Literal["id", "username"] = Query("id"), -# current_user: DBUser = Depends(get_current_user), -# db: AsyncSession = Depends(get_db), -# ): -# """获取指定游戏模式的用户信息""" -# user = await get_user_by_lookup(db, user_lookup, key) -# if not user: -# raise HTTPException(status_code=404, detail="User not found") - -# return await convert_db_user_to_api_user(user, mode) - - -# @router.get("/users/{user_lookup}", response_model=ApiUser) -# @router.get("/users/{user_lookup}/", response_model=ApiUser) -# async def get_user_default( -# user_lookup: str, -# key: Literal["id", "username"] = Query("id"), -# current_user: DBUser = Depends(get_current_user), -# db: AsyncSession = Depends(get_db), -# ): -# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)""" -# user = await get_user_by_lookup(db, user_lookup, key) -# if not user: -# raise HTTPException(status_code=404, detail="User not found") - -# return await convert_db_user_to_api_user(user, "osu") - - -@router.get("/users/{user}/{ruleset}", response_model=ApiUser) -@router.get("/users/{user}/", response_model=ApiUser) -@router.get("/users/{user}", response_model=ApiUser) +@router.get("/users/{user}/{ruleset}", response_model=UserResp) +@router.get("/users/{user}/", response_model=UserResp) +@router.get("/users/{user}", response_model=UserResp) async def get_user_info( user: str, - ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", + ruleset: GameMode | None = None, session: AsyncSession = Depends(get_db), ): searched_user = ( await session.exec( - DBUser.all_select_clause().where( - DBUser.id == int(user) + select(User).where( + User.id == int(user) if user.isdigit() - else DBUser.name == user.removeprefix("@") + else User.username == user.removeprefix("@") ) ) ).first() if not searched_user: raise HTTPException(404, detail="User not found") - return await convert_db_user_to_api_user(searched_user, ruleset=ruleset) + return await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDED, + ruleset=ruleset, + ) diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 276140f..4bab451 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -6,9 +6,9 @@ import time from typing import Any from app.config import settings +from app.exception import InvokeException from app.log import logger from app.models.signalr import UserState -from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, CompletionPacket, @@ -21,7 +21,6 @@ from app.signalr.store import ResultStore from app.signalr.utils import get_signature from fastapi import WebSocket -from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect @@ -49,7 +48,7 @@ class Client: self.connection_id = connection_id self.connection_token = connection_token self.connection = connection - self.procotol = protocol + self.protocol = protocol self._listen_task: asyncio.Task | None = None self._ping_task: asyncio.Task | None = None self._store = ResultStore() @@ -62,14 +61,14 @@ class Client: return int(self.connection_id) async def send_packet(self, packet: Packet): - await self.connection.send_bytes(self.procotol.encode(packet)) + await self.connection.send_bytes(self.protocol.encode(packet)) async def receive_packets(self) -> list[Packet]: message = await self.connection.receive() d = message.get("bytes") or message.get("text", "").encode() if not d: return [] - return self.procotol.decode(d) + return self.protocol.decode(d) async def _ping(self): while True: @@ -263,10 +262,9 @@ class Hub[TState: UserState]: for name, param in signature.parameters.items(): if name == "self" or param.annotation is Client: continue - if issubclass(param.annotation, BaseModel): - call_params.append(param.annotation.model_validate(args.pop(0))) - else: - call_params.append(args.pop(0)) + call_params.append( + client.protocol.validate_object(args.pop(0), param.annotation) + ) return await method_(client, *call_params) async def call(self, client: Client, method: str, *args: Any) -> Any: diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 821d831..08ee035 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -2,15 +2,15 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine +from datetime import UTC, datetime from typing import override -from app.database.relationship import Relationship, RelationshipType -from app.dependencies.database import engine +from app.database import Relationship, RelationshipType, User +from app.dependencies.database import engine, get_redis from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity from .hub import Client, Hub -from pydantic import TypeAdapter from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -30,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]): ) -> set[Coroutine]: if store is not None and not store.pushable: return set() - data = store.to_dict() if store else None + data = store.for_push if store else None return { self.broadcast_group_call( self.online_presence_watchers_group(), @@ -54,6 +54,18 @@ class MetadataHub(Hub[MetadataClientState]): async def _clean_state(self, state: MetadataClientState) -> None: if state.pushable: await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None)) + redis = get_redis() + if await redis.exists(f"metadata:online:{state.connection_id}"): + await redis.delete(f"metadata:online:{state.connection_id}") + async with AsyncSession(engine) as session: + async with session.begin(): + user = ( + await session.exec( + select(User).where(User.id == int(state.connection_id)) + ) + ).one() + user.last_visit = datetime.now(UTC) + await session.commit() @override def create_state(self, client: Client) -> MetadataClientState: @@ -89,10 +101,14 @@ class MetadataHub(Hub[MetadataClientState]): self.friend_presence_watchers_group(friend_id), "FriendPresenceUpdated", friend_id, - friend_state.to_dict(), + friend_state.for_push + if friend_state.pushable + else None, ) ) await asyncio.gather(*tasks) + redis = get_redis() + await redis.set(f"metadata:online:{user_id}", "") async def UpdateStatus(self, client: Client, status: int) -> None: status_ = OnlineStatus(status) @@ -107,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]): client, "UserPresenceUpdated", user_id, - store.to_dict(), + store.for_push, ) ) await asyncio.gather(*tasks) - async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None: + async def UpdateActivity( + self, client: Client, activity: UserActivity | None + ) -> None: user_id = int(client.connection_id) - activity = ( - TypeAdapter(UserActivity).validate_python(activity_dict) - if activity_dict - else None - ) store = self.get_or_create_state(client) - store.user_activity = activity + store.activity = activity tasks = self.broadcast_tasks(user_id, store) tasks.add( self.call_noblock( client, "UserPresenceUpdated", user_id, - store.to_dict(), + store.for_push, ) ) await asyncio.gather(*tasks) @@ -139,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]): client, "UserPresenceUpdated", user_id, - store.to_dict(), + store, ) for user_id, store in self.state.items() if store.pushable diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 72b4a52..3688efa 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -1,6 +1,1138 @@ from __future__ import annotations -from .hub import Hub +import asyncio +from datetime import UTC, datetime, timedelta +from typing import override + +from app.database import Room +from app.database.beatmap import Beatmap +from app.database.lazer_user import User +from app.database.multiplayer_event import MultiplayerEvent +from app.database.playlists import Playlist +from app.database.relationship import Relationship, RelationshipType +from app.dependencies.database import engine, get_redis +from app.exception import InvokeException +from app.log import logger +from app.models.mods import APIMod +from app.models.multiplayer_hub import ( + BeatmapAvailability, + ForceGameplayStartCountdown, + GameplayAbortReason, + MatchRequest, + MatchServerEvent, + MatchStartCountdown, + MatchStartedEventDetail, + MultiplayerClientState, + MultiplayerRoom, + MultiplayerRoomSettings, + MultiplayerRoomUser, + PlaylistItem, + ServerMultiplayerRoom, + ServerShuttingDownCountdown, + StartMatchCountdownRequest, + StopCountdownRequest, +) +from app.models.room import ( + DownloadState, + MatchType, + MultiplayerRoomState, + MultiplayerUserState, + RoomCategory, + RoomStatus, +) +from app.models.score import GameMode + +from .hub import Client, Hub + +from sqlalchemy import update +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +GAMEPLAY_LOAD_TIMEOUT = 30 -class MultiplayerHub(Hub): ... +class MultiplayerEventLogger: + def __init__(self): + pass + + async def log_event(self, event: MultiplayerEvent): + try: + async with AsyncSession(engine) as session: + session.add(event) + await session.commit() + except Exception as e: + logger.warning(f"Failed to log multiplayer room event to database: {e}") + + async def room_created(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="room_created", + ) + await self.log_event(event) + + async def room_disbanded(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="room_disbanded", + ) + await self.log_event(event) + + async def player_joined(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_joined", + ) + await self.log_event(event) + + async def player_left(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_left", + ) + await self.log_event(event) + + async def player_kicked(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="player_kicked", + ) + await self.log_event(event) + + async def host_changed(self, room_id: int, user_id: int): + event = MultiplayerEvent( + room_id=room_id, + user_id=user_id, + event_type="host_changed", + ) + await self.log_event(event) + + async def game_started( + self, room_id: int, playlist_item_id: int, details: MatchStartedEventDetail + ): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_started", + event_detail=details, # pyright: ignore[reportArgumentType] + ) + await self.log_event(event) + + async def game_aborted(self, room_id: int, playlist_item_id: int): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_aborted", + ) + await self.log_event(event) + + async def game_completed(self, room_id: int, playlist_item_id: int): + event = MultiplayerEvent( + room_id=room_id, + playlist_item_id=playlist_item_id, + event_type="game_completed", + ) + await self.log_event(event) + + +class MultiplayerHub(Hub[MultiplayerClientState]): + @override + def __init__(self): + super().__init__() + self.rooms: dict[int, ServerMultiplayerRoom] = {} + self.event_logger = MultiplayerEventLogger() + + @staticmethod + def group_id(room: int) -> str: + return f"room:{room}" + + @override + def create_state(self, client: Client) -> MultiplayerClientState: + return MultiplayerClientState( + connection_id=client.connection_id, + connection_token=client.connection_token, + ) + + @override + async def _clean_state(self, state: MultiplayerClientState): + user_id = int(state.connection_id) + if state.room_id != 0 and state.room_id in self.rooms: + server_room = self.rooms[state.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == user_id), None) + if user is not None: + await self.make_user_leave( + self.get_client_by_id(str(user_id)), server_room, user + ) + + async def CreateRoom(self, client: Client, room: MultiplayerRoom): + logger.info(f"[MultiplayerHub] {client.user_id} creating room") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") + async with AsyncSession(engine) as session: + async with session: + db_room = Room( + name=room.settings.name, + category=RoomCategory.NORMAL, + type=room.settings.match_type, + queue_mode=room.settings.queue_mode, + auto_skip=room.settings.auto_skip, + auto_start_duration=int( + room.settings.auto_start_duration.total_seconds() + ), + host_id=client.user_id, + status=RoomStatus.IDLE, + ) + session.add(db_room) + await session.commit() + await session.refresh(db_room) + item = room.playlist[0] + item.owner_id = client.user_id + room.room_id = db_room.id + starts_at = db_room.starts_at or datetime.now(UTC) + await Playlist.add_to_db(item, db_room.id, session) + server_room = ServerMultiplayerRoom( + room=room, + category=RoomCategory.NORMAL, + start_at=starts_at, + hub=self, + ) + self.rooms[room.room_id] = server_room + await server_room.set_handler() + await self.event_logger.room_created(room.room_id, client.user_id) + return await self.JoinRoomWithPassword( + client, room.room_id, room.settings.password + ) + + async def JoinRoom(self, client: Client, room_id: int): + return self.JoinRoomWithPassword(client, room_id, "") + + async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str): + logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") + user = MultiplayerRoomUser(user_id=client.user_id) + if room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[room_id] + room = server_room.room + for u in room.users: + if u.user_id == client.user_id: + raise InvokeException("You are already in this room") + if room.settings.password != password: + raise InvokeException("Incorrect password") + if room.host is None: + # from CreateRoom + room.host = user + store.room_id = room_id + await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user) + room.users.append(user) + self.add_to_group(client, self.group_id(room_id)) + await server_room.match_type_handler.handle_join(user) + await self.event_logger.player_joined(room_id, user.user_id) + return room + + async def ChangeBeatmapAvailability( + self, client: Client, beatmap_availability: BeatmapAvailability + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + availability = user.availability + if ( + availability.state == beatmap_availability.state + and availability.download_progress == beatmap_availability.download_progress + ): + return + user.availability = beatmap_availability + await self.broadcast_group_call( + self.group_id(store.room_id), + "UserBeatmapAvailabilityChanged", + user.user_id, + (beatmap_availability), + ) + + async def AddPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.add_item( + item, + user, + ) + + async def EditPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.edit_item( + item, + user, + ) + + async def RemovePlaylistItem(self, client: Client, item_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.remove_item( + item_id, + user, + ) + + async def setting_changed(self, room: ServerMultiplayerRoom, beatmap_changed: bool): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "SettingsChanged", + (room.room.settings), + ) + + async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemAdded", + (item), + ) + + async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemRemoved", + item_id, + ) + + async def playlist_changed( + self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemChanged", + (item), + ) + + async def ChangeUserStyle( + self, client: Client, beatmap_id: int | None, ruleset_id: int | None + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_style( + beatmap_id, + ruleset_id, + server_room, + user, + ) + + async def validate_styles(self, room: ServerMultiplayerRoom): + if not room.queue.current_item.freestyle: + for user in room.room.users: + await self.change_user_style( + None, + None, + room, + user, + ) + async with AsyncSession(engine) as session: + beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + beatmap_ids = ( + await session.exec( + select(Beatmap.id, Beatmap.mode).where( + Beatmap.beatmapset_id == beatmap.beatmapset_id, + ) + ) + ).all() + for user in room.room.users: + beatmap_id = user.beatmap_id + ruleset_id = user.ruleset_id + user_beatmap = next( + (b for b in beatmap_ids if b[0] == beatmap_id), + None, + ) + if beatmap_id is not None and user_beatmap is None: + beatmap_id = None + beatmap_ruleset = user_beatmap[1] if user_beatmap else beatmap.mode + if ( + ruleset_id is not None + and beatmap_ruleset != GameMode.OSU + and ruleset_id != beatmap_ruleset + ): + ruleset_id = None + await self.change_user_style( + beatmap_id, + ruleset_id, + room, + user, + ) + + for user in room.room.users: + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, user.mods + ) + if not is_valid: + await self.change_user_mods(valid_mods, room, user) + + async def change_user_style( + self, + beatmap_id: int | None, + ruleset_id: int | None, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + if user.beatmap_id == beatmap_id and user.ruleset_id == ruleset_id: + return + + if beatmap_id is not None or ruleset_id is not None: + if not room.queue.current_item.freestyle: + raise InvokeException("Current item does not allow free user styles.") + + async with AsyncSession(engine) as session: + item_beatmap = await session.get( + Beatmap, room.queue.current_item.beatmap_id + ) + if item_beatmap is None: + raise InvokeException("Item beatmap not found") + + user_beatmap = ( + item_beatmap + if beatmap_id is None + else await session.get(Beatmap, beatmap_id) + ) + + if user_beatmap is None: + raise InvokeException("Invalid beatmap selected.") + + if user_beatmap.beatmapset_id != item_beatmap.beatmapset_id: + raise InvokeException( + "Selected beatmap is not from the same beatmap set." + ) + + if ( + ruleset_id is not None + and user_beatmap.mode != GameMode.OSU + and ruleset_id != user_beatmap.mode + ): + raise InvokeException( + "Selected ruleset is not supported for the given beatmap." + ) + + user.beatmap_id = beatmap_id + user.ruleset_id = ruleset_id + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStyleChanged", + user.user_id, + beatmap_id, + ruleset_id, + ) + + async def ChangeUserMods(self, client: Client, new_mods: list[APIMod]): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_mods(new_mods, server_room, user) + + async def change_user_mods( + self, + new_mods: list[APIMod], + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, new_mods + ) + if not is_valid: + incompatible_mods = [ + mod["acronym"] for mod in new_mods if mod not in valid_mods + ] + raise InvokeException( + f"Incompatible mods were selected: {','.join(incompatible_mods)}" + ) + + if user.mods == valid_mods: + return + + user.mods = valid_mods + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserModsChanged", + user.user_id, + valid_mods, + ) + + async def validate_user_stare( + self, + room: ServerMultiplayerRoom, + old: MultiplayerUserState, + new: MultiplayerUserState, + ): + match new: + case MultiplayerUserState.IDLE: + if old.is_playing: + raise InvokeException( + "Cannot return to idle without aborting gameplay." + ) + case MultiplayerUserState.READY: + if old != MultiplayerUserState.IDLE: + raise InvokeException(f"Cannot change state from {old} to {new}") + if room.queue.current_item.expired: + raise InvokeException( + "Cannot ready up while all items have been played." + ) + case MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.LOADED: + if old != MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.READY_FOR_GAMEPLAY: + if old != MultiplayerUserState.LOADED: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.PLAYING: + raise InvokeException("State is managed by the server.") + case MultiplayerUserState.FINISHED_PLAY: + if old != MultiplayerUserState.PLAYING: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.RESULTS: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.SPECTATING: + if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY): + raise InvokeException(f"Cannot change state from {old} to {new}") + + async def ChangeState(self, client: Client, state: MultiplayerUserState): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if user.state == state: + return + match state: + case MultiplayerUserState.IDLE: + if user.state.is_playing: + return + case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY: + if not user.state.is_playing: + return + await self.validate_user_stare( + server_room, + user.state, + state, + ) + await self.change_user_state(server_room, user, state) + if state == MultiplayerUserState.SPECTATING and ( + room.state == MultiplayerRoomState.PLAYING + or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + ): + await self.call_noblock(client, "LoadRequested") + await self.update_room_state(server_room) + + async def change_user_state( + self, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + state: MultiplayerUserState, + ): + user.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStateChanged", + user.user_id, + user.state, + ) + + async def update_room_state(self, room: ServerMultiplayerRoom): + match room.room.state: + case MultiplayerRoomState.OPEN: + if room.room.settings.auto_start_enabled: + if ( + not room.queue.current_item.expired + and any( + u.state == MultiplayerUserState.READY + for u in room.room.users + ) + and not any( + isinstance(countdown, MatchStartCountdown) + for countdown in room.room.active_countdowns + ) + ): + await room.start_countdown( + MatchStartCountdown( + time_remaining=room.room.settings.auto_start_duration + ), + self.start_match, + ) + case MultiplayerRoomState.WAITING_FOR_LOAD: + played_count = len( + [True for user in room.room.users if user.state.is_playing] + ) + ready_count = len( + [ + True + for user in room.room.users + if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY + ] + ) + if played_count == ready_count: + await self.start_gameplay(room) + case MultiplayerRoomState.PLAYING: + if all( + u.state != MultiplayerUserState.PLAYING for u in room.room.users + ): + any_user_finished_playing = False + for u in filter( + lambda u: u.state == MultiplayerUserState.FINISHED_PLAY, + room.room.users, + ): + any_user_finished_playing = True + await self.change_user_state( + room, u, MultiplayerUserState.RESULTS + ) + await self.change_room_state(room, MultiplayerRoomState.OPEN) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "ResultsReady", + ) + if any_user_finished_playing: + await self.event_logger.game_completed( + room.room.room_id, + room.queue.current_item.id, + ) + else: + await self.event_logger.game_aborted( + room.room.room_id, + room.queue.current_item.id, + ) + await room.queue.finish_current_item() + + async def change_room_state( + self, room: ServerMultiplayerRoom, state: MultiplayerRoomState + ): + room.room.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "RoomStateChanged", + state, + ) + + async def StartMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + # Check host state - host must be ready or spectating + if room.host.state not in ( + MultiplayerUserState.SPECTATING, + MultiplayerUserState.READY, + ): + raise InvokeException("Can't start match when the host is not ready.") + + # Check if any users are ready + if all(u.state != MultiplayerUserState.READY for u in room.users): + raise InvokeException("Can't start match when no users are ready.") + + await self.start_match(server_room) + + async def start_match(self, room: ServerMultiplayerRoom): + if room.room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Can't start match when already in a running state.") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + ready_users = [ + u + for u in room.room.users + if u.availability.state == DownloadState.LOCALLY_AVAILABLE + and ( + u.state == MultiplayerUserState.READY + or u.state == MultiplayerUserState.IDLE + ) + ] + await asyncio.gather( + *[ + self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD) + for u in ready_users + ] + ) + await self.change_room_state( + room, + MultiplayerRoomState.WAITING_FOR_LOAD, + ) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "LoadRequested", + ) + await room.start_countdown( + ForceGameplayStartCountdown( + time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) + ), + self.start_gameplay, + ) + await self.event_logger.game_started( + room.room.room_id, + room.queue.current_item.id, + details=room.match_type_handler.get_details(), + ) + + async def start_gameplay(self, room: ServerMultiplayerRoom): + if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD: + raise InvokeException("Room is not ready for gameplay") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + playing = False + played_user = 0 + for user in room.room.users: + client = self.get_client_by_id(str(user.user_id)) + if client is None: + continue + + if user.state in ( + MultiplayerUserState.READY_FOR_GAMEPLAY, + MultiplayerUserState.LOADED, + ): + playing = True + played_user += 1 + await self.change_user_state(room, user, MultiplayerUserState.PLAYING) + await self.call_noblock(client, "GameplayStarted") + elif user.state == MultiplayerUserState.WAITING_FOR_LOAD: + await self.change_user_state(room, user, MultiplayerUserState.IDLE) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "GameplayAborted", + GameplayAbortReason.LOAD_TOOK_TOO_LONG, + ) + await self.change_room_state( + room, + (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN), + ) + if playing: + redis = get_redis() + await redis.set( + f"multiplayer:{room.room.room_id}:gameplay:players", + played_user, + ex=3600, + ) + + async def send_match_event( + self, room: ServerMultiplayerRoom, event: MatchServerEvent + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchEvent", + event, + ) + + async def make_user_leave( + self, + client: Client, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + kicked: bool = False, + ): + self.remove_from_group(client, self.group_id(room.room.room_id)) + room.room.users.remove(user) + + if len(room.room.users) == 0: + await self.end_room(room) + await self.update_room_state(room) + if ( + len(room.room.users) != 0 + and room.room.host + and room.room.host.user_id == user.user_id + ): + next_host = room.room.users[0] + await self.set_host(room, next_host) + + if kicked: + await self.call_noblock(client, "UserKicked", user) + await self.broadcast_group_call( + self.group_id(room.room.room_id), "UserKicked", user + ) + else: + await self.broadcast_group_call( + self.group_id(room.room.room_id), "UserLeft", user + ) + + target_store = self.state.get(user.user_id) + if target_store: + target_store.room_id = 0 + + async def end_room(self, room: ServerMultiplayerRoom): + assert room.room.host + async with AsyncSession(engine) as session: + await session.execute( + update(Room) + .where(col(Room.id) == room.room.room_id) + .values( + name=room.room.settings.name, + ended_at=datetime.now(UTC), + type=room.room.settings.match_type, + queue_mode=room.room.settings.queue_mode, + auto_skip=room.room.settings.auto_skip, + auto_start_duration=int( + room.room.settings.auto_start_duration.total_seconds() + ), + host_id=room.room.host.user_id, + ) + ) + await self.event_logger.room_disbanded( + room.room.room_id, + room.room.host.user_id, + ) + del self.rooms[room.room.room_id] + + async def LeaveRoom(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + return + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.event_logger.player_left( + room.room_id, + user.user_id, + ) + await self.make_user_leave(client, server_room, user) + + async def KickUser(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if user_id == client.user_id: + raise InvokeException("Can't kick self") + + user = next((u for u in room.users if u.user_id == user_id), None) + if user is None: + raise InvokeException("User not found in this room") + + await self.event_logger.player_kicked( + room.room_id, + user.user_id, + ) + target_client = self.get_client_by_id(str(user.user_id)) + if target_client is None: + return + await self.make_user_leave(target_client, server_room, user, kicked=True) + + async def set_host(self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser): + room.room.host = user + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "HostChanged", + user.user_id, + ) + + async def TransferHost(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + new_host = next((u for u in room.users if u.user_id == user_id), None) + if new_host is None: + raise InvokeException("User not found in this room") + await self.event_logger.host_changed( + room.room_id, + new_host.user_id, + ) + await self.set_host(server_room, new_host) + + async def AbortGameplay(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if not user.state.is_playing: + raise InvokeException("Cannot abort gameplay while not in a gameplay state") + + await self.change_user_state( + server_room, + user, + MultiplayerUserState.IDLE, + ) + await self.update_room_state(server_room) + + async def AbortMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if ( + room.state != MultiplayerRoomState.PLAYING + and room.state != MultiplayerRoomState.WAITING_FOR_LOAD + ): + raise InvokeException("Cannot abort a match that hasn't started.") + + await asyncio.gather( + *[ + self.change_user_state(server_room, u, MultiplayerUserState.IDLE) + for u in room.users + if u.state.is_playing + ] + ) + await self.broadcast_group_call( + self.group_id(room.room_id), + "GameplayAborted", + GameplayAbortReason.HOST_ABORTED, + ) + await self.update_room_state(server_room) + + async def change_user_match_state( + self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchUserStateChanged", + user.user_id, + user.match_state, + ) + + async def change_room_match_state(self, room: ServerMultiplayerRoom): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchRoomStateChanged", + room.room.match_state, + ) + + async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot change settings while playing") + + if settings.match_type == MatchType.PLAYLISTS: + raise InvokeException("Invalid match type selected") + + previous_settings = room.settings + room.settings = settings + + if previous_settings.match_type != settings.match_type: + await server_room.set_handler() + if previous_settings.queue_mode != settings.queue_mode: + await server_room.queue.update_queue_mode() + + await self.setting_changed(server_room, beatmap_changed=False) + await self.update_room_state(server_room) + + async def SendMatchRequest(self, client: Client, request: MatchRequest): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if isinstance(request, StartMatchCountdownRequest): + if room.host and room.host.user_id != user.user_id: + raise InvokeException("You are not the host of this room") + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot start a countdown during ongoing play") + await server_room.start_countdown( + MatchStartCountdown(time_remaining=request.duration), + self.start_match, + ) + elif isinstance(request, StopCountdownRequest): + countdown = next( + (c for c in room.active_countdowns if c.id == request.id), + None, + ) + if countdown is None: + return + if ( + isinstance(countdown, MatchStartCountdown) + and room.settings.auto_start_enabled + ) or isinstance( + countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown) + ): + raise InvokeException("Cannot stop the requested countdown") + + await server_room.stop_countdown(countdown) + else: + await server_room.match_type_handler.handle_request(user, request) + + async def InvitePlayer(self, client: Client, user_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + async with AsyncSession(engine) as session: + db_user = await session.get(User, user_id) + target_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == user_id, + Relationship.target_id == client.user_id, + ) + ) + ).first() + inviter_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == client.user_id, + Relationship.target_id == user_id, + ) + ) + ).first() + if db_user is None: + raise InvokeException("User not found") + if db_user.id == client.user_id: + raise InvokeException("You cannot invite yourself") + if db_user.id in [u.user_id for u in room.users]: + raise InvokeException("User already invited") + if db_user.is_restricted: + raise InvokeException("User is restricted") + if ( + inviter_relationship + and inviter_relationship.type == RelationshipType.BLOCK + ): + raise InvokeException("Cannot perform action due to user being blocked") + if ( + target_relationship + and target_relationship.type == RelationshipType.BLOCK + ): + raise InvokeException("Cannot perform action due to user being blocked") + if ( + db_user.pm_friends_only + and target_relationship is not None + and target_relationship.type != RelationshipType.FOLLOW + ): + raise InvokeException( + "Cannot perform action " + "because user has disabled non-friend communications" + ) + + target_client = self.get_client_by_id(str(user_id)) + if target_client is None: + raise InvokeException("User is not online") + await self.call_noblock( + target_client, + "Invited", + client.user_id, + room.room_id, + room.settings.password, + ) diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index 0d0899e..b9a3c99 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -7,15 +7,13 @@ import struct import time from typing import override -from app.database import Beatmap +from app.database import Beatmap, User from app.database.score import Score from app.database.score_token import ScoreToken -from app.database.user import User from app.dependencies.database import engine from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int -from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt -from app.models.signalr import serialize_to_list +from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics from app.models.spectator_hub import ( APIUser, FrameDataBundle, @@ -70,8 +68,8 @@ def save_replay( md5: str, username: str, score: Score, - statistics: ScoreStatisticsInt, - maximum_statistics: ScoreStatisticsInt, + statistics: ScoreStatistics, + maximum_statistics: ScoreStatistics, frames: list[LegacyReplayFrame], ) -> None: data = bytearray() @@ -108,8 +106,8 @@ def save_replay( last_time = 0 for frame in frames: frame_strs.append( - f"{frame.time - last_time}|{frame.x or 0.0}" - f"|{frame.y or 0.0}|{frame.button_state}" + f"{frame.time - last_time}|{frame.mouse_x or 0.0}" + f"|{frame.mouse_y or 0.0}|{frame.button_state}" ) last_time = frame.time frame_strs.append("-12345|0|0|0") @@ -166,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]): async def on_client_connect(self, client: Client) -> None: tasks = [ - self.call_noblock( - client, "UserBeganPlaying", user_id, serialize_to_list(store.state) - ) + self.call_noblock(client, "UserBeganPlaying", user_id, store.state) for user_id, store in self.state.items() if store.state is not None ] @@ -197,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]): ).first() if not user: return - name = user.name + name = user.username store.state = state store.beatmap_status = beatmap.beatmap_status store.checksum = beatmap.checksum @@ -215,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserBeganPlaying", user_id, - serialize_to_list(state), + state, ) async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: @@ -223,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]): state = self.get_or_create_state(client) if not state.score: return - state.score.score_info.acc = frame_data.header.acc + state.score.score_info.accuracy = frame_data.header.accuracy state.score.score_info.combo = frame_data.header.combo state.score.score_info.max_combo = frame_data.header.max_combo state.score.score_info.statistics = frame_data.header.statistics @@ -234,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserSentFrames", user_id, - frame_data.model_dump(), + frame_data, ) async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: @@ -297,19 +293,18 @@ class SpectatorHub(Hub[StoreClientState]): score_record.id, ) # save replay - if store.state.state == SpectatedUserState.Passed: - score_record.has_replay = True - await session.commit() - await session.refresh(score_record) - save_replay( - ruleset_id=store.ruleset_id, - md5=store.checksum, - username=store.score.score_info.user.name, - score=score_record, - statistics=store.score.score_info.statistics, - maximum_statistics=store.score.score_info.maximum_statistics, - frames=store.score.replay_frames, - ) + score_record.has_replay = True + await session.commit() + await session.refresh(score_record) + save_replay( + ruleset_id=store.ruleset_id, + md5=store.checksum, + username=store.score.score_info.user.name, + score=score_record, + statistics=store.score.score_info.statistics, + maximum_statistics=store.score.score_info.maximum_statistics, + frames=store.score.replay_frames, + ) async def _end_session(self, user_id: int, state: SpectatorState) -> None: if state.state == SpectatedUserState.Playing: @@ -318,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserFinishedPlaying", user_id, - serialize_to_list(state) if state else None, + state, ) async def StartWatchingUser(self, client: Client, target_id: int) -> None: @@ -329,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]): client, "UserBeganPlaying", target_id, - serialize_to_list(target_store.state), + target_store.state, ) store = self.get_or_create_state(client) store.watched_user.add(target_id) @@ -339,7 +334,7 @@ class SpectatorHub(Hub[StoreClientState]): async with AsyncSession(engine) as session: async with session.begin(): username = ( - await session.exec(select(User.name).where(User.id == user_id)) + await session.exec(select(User.username).where(User.id == user_id)) ).first() if not username: return diff --git a/app/signalr/packet.py b/app/signalr/packet.py index e361ef8..8949f4b 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -1,14 +1,24 @@ from __future__ import annotations from dataclasses import dataclass -from enum import IntEnum +import datetime +from enum import Enum, IntEnum +import inspect import json +from types import NoneType, UnionType from typing import ( Any, Protocol as TypingProtocol, + Union, + get_args, + get_origin, ) +from app.models.signalr import SignalRMeta, SignalRUnionMessage +from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal + import msgpack_lazer_api as m +from pydantic import BaseModel SEP = b"\x1e" @@ -73,8 +83,67 @@ class Protocol(TypingProtocol): @staticmethod def encode(packet: Packet) -> bytes: ... + @classmethod + def validate_object(cls, v: Any, typ: type) -> Any: ... + class MsgpackProtocol: + @classmethod + def serialize_msgpack(cls, v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return cls.serialize_to_list(v) + elif issubclass(typ, list): + return [cls.serialize_msgpack(item) for item in v] + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif issubclass(typ, datetime.timedelta): + return int(v.total_seconds() * 10_000_000) + elif isinstance(v, dict): + return { + cls.serialize_msgpack(k): cls.serialize_msgpack(value) + for k, value in v.items() + } + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + @classmethod + def serialize_to_list(cls, value: BaseModel) -> list[Any]: + values = [] + for field, info in value.__class__.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.member_ignore: + continue + values.append(cls.serialize_msgpack(v=getattr(value, field))) + if issubclass(value.__class__, SignalRUnionMessage): + return [value.__class__.union_type, values] + else: + return values + + @staticmethod + def process_object(v: Any, typ: type[BaseModel]) -> Any: + if isinstance(v, list): + d = {} + i = 0 + for field, info in typ.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.member_ignore: + continue + anno = info.annotation + if anno is None: + d[camel_to_snake(field)] = v[i] + else: + d[field] = MsgpackProtocol.validate_object(v[i], anno) + i += 1 + return d + return v + @staticmethod def _encode_varint(value: int) -> bytes: result = [] @@ -140,6 +209,53 @@ class MsgpackProtocol: ] raise ValueError(f"Unsupported packet type: {packet_type}") + @classmethod + def validate_object(cls, v: Any, typ: type) -> Any: + if issubclass(typ, BaseModel): + return typ.model_validate(obj=cls.process_object(v, typ)) + elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): + return v[0] + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + return datetime.timedelta(seconds=int(v / 10_000_000)) + elif get_origin(typ) is list: + return [cls.validate_object(item, get_args(typ)[0]) for item in v] + elif inspect.isclass(typ) and issubclass(typ, Enum): + list_ = list(typ) + return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) + elif get_origin(typ) is dict: + return { + cls.validate_object(k, get_args(typ)[0]): cls.validate_object( + v, get_args(typ)[1] + ) + for k, v in v.items() + } + elif (origin := get_origin(typ)) is Union or origin is UnionType: + args = get_args(typ) + if len(args) == 2 and NoneType in args: + non_none_args = [arg for arg in args if arg is not NoneType] + if len(non_none_args) == 1: + if v is None: + return None + return cls.validate_object(v, non_none_args[0]) + + # suppose use `MessagePack-CSharp Union | None` + # except `X (Other Type) | None` + if NoneType in args and v is None: + return None + if not all( + issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args + ): + raise ValueError( + f"Cannot validate {v} to {typ}, " + "only SignalRUnionMessage subclasses are supported" + ) + union_type = v[0] + for arg in args: + assert issubclass(arg, SignalRUnionMessage) + if arg.union_type == union_type: + return cls.validate_object(v[1], arg) + return v + @staticmethod def encode(packet: Packet) -> bytes: payload = [packet.type.value, packet.header or {}] @@ -151,20 +267,24 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append(packet.arguments) + payload.append( + [MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments] + ) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): result_kind = 2 if packet.error: result_kind = 1 - elif packet.result is None: + elif packet.result is not None: result_kind = 3 payload.extend( [ packet.invocation_id, result_kind, - packet.error or packet.result or None, + packet.error + or MsgpackProtocol.serialize_msgpack(packet.result) + or None, ] ) elif isinstance(packet, ClosePacket): @@ -181,6 +301,86 @@ class MsgpackProtocol: class JSONProtocol: + @classmethod + def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False): + typ = v.__class__ + if issubclass(typ, BaseModel): + return cls.serialize_model(v, in_union) + elif isinstance(v, dict): + return { + cls.serialize_to_json(k, True): cls.serialize_to_json(value) + for k, value in v.items() + } + elif isinstance(v, list): + return [cls.serialize_to_json(item) for item in v] + elif isinstance(v, datetime.datetime): + return v.isoformat() + elif isinstance(v, datetime.timedelta): + # d.hh:mm:ss + total_seconds = int(v.total_seconds()) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02}:{minutes:02}:{seconds:02}" + elif isinstance(v, Enum) and dict_key: + return v.value + elif isinstance(v, Enum): + list_ = list(typ) + return list_.index(v) + return v + + @classmethod + def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]: + d = {} + is_union = issubclass(v.__class__, SignalRUnionMessage) + for field, info in v.__class__.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.json_ignore: + continue + name = ( + snake_to_camel( + field, + metadata.use_abbr if metadata else True, + ) + if not is_union + else snake_to_pascal( + field, + metadata.use_abbr if metadata else True, + ) + ) + d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union) + if is_union and not in_union: + return { + "$dtype": v.__class__.__name__, + "$value": d, + } + return d + + @staticmethod + def process_object( + v: Any, typ: type[BaseModel], from_union: bool = False + ) -> dict[str, Any]: + d = {} + for field, info in typ.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.json_ignore: + continue + name = ( + snake_to_camel(field, metadata.use_abbr if metadata else True) + if not from_union + else snake_to_pascal(field, metadata.use_abbr if metadata else True) + ) + value = v.get(name) + anno = typ.model_fields[field].annotation + if anno is None: + d[field] = value + continue + d[field] = JSONProtocol.validate_object(value, anno) + return d + @staticmethod def decode(input: bytes) -> list[Packet]: packets_raw = input.removesuffix(SEP).split(SEP) @@ -225,6 +425,63 @@ class JSONProtocol: ] raise ValueError(f"Unsupported packet type: {packet_type}") + @classmethod + def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any: + if issubclass(typ, BaseModel): + return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) + elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): + return datetime.datetime.fromisoformat(v) + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + # d.hh:mm:ss + parts = v.split(":") + if len(parts) == 3: + return datetime.timedelta( + hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]) + ) + elif len(parts) == 2: + return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) + elif len(parts) == 1: + return datetime.timedelta(seconds=int(parts[0])) + elif get_origin(typ) is list: + return [cls.validate_object(item, get_args(typ)[0]) for item in v] + elif inspect.isclass(typ) and issubclass(typ, Enum): + list_ = list(typ) + return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) + elif get_origin(typ) is dict: + return { + cls.validate_object(k, get_args(typ)[0]): cls.validate_object( + v, get_args(typ)[1] + ) + for k, v in v.items() + } + elif (origin := get_origin(typ)) is Union or origin is UnionType: + args = get_args(typ) + if len(args) == 2 and NoneType in args: + non_none_args = [arg for arg in args if arg is not NoneType] + if len(non_none_args) == 1: + if v is None: + return None + return cls.validate_object(v, non_none_args[0]) + + # suppose use `MessagePack-CSharp Union | None` + # except `X (Other Type) | None` + if NoneType in args and v is None: + return None + if not all( + issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args + ): + raise ValueError( + f"Cannot validate {v} to {typ}, " + "only SignalRUnionMessage subclasses are supported" + ) + # https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs + union_type = v["$dtype"] + for arg in args: + assert issubclass(arg, SignalRUnionMessage) + if arg.__name__ == union_type: + return cls.validate_object(v["$value"], arg, True) + return v + @staticmethod def encode(packet: Packet) -> bytes: payload: dict[str, Any] = { @@ -241,7 +498,9 @@ class JSONProtocol: if packet.invocation_id is not None: payload["invocationId"] = packet.invocation_id if packet.arguments is not None: - payload["arguments"] = packet.arguments + payload["arguments"] = [ + JSONProtocol.serialize_to_json(arg) for arg in packet.arguments + ] if packet.stream_ids is not None: payload["streamIds"] = packet.stream_ids elif isinstance(packet, CompletionPacket): @@ -253,7 +512,7 @@ class JSONProtocol: if packet.error is not None: payload["error"] = packet.error if packet.result is not None: - payload["result"] = packet.result + payload["result"] = JSONProtocol.serialize_to_json(packet.result) elif isinstance(packet, PingPacket): pass elif isinstance(packet, ClosePacket): diff --git a/app/utils.py b/app/utils.py index 9008706..22f06dd 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,465 +1,92 @@ from __future__ import annotations -from datetime import UTC, datetime - -from app.database import ( - LazerUserCounts, - LazerUserProfile, - LazerUserStatistics, - User as DBUser, -) -from app.models.user import ( - Country, - Cover, - DailyChallengeStats, - GradeCounts, - Kudosu, - Level, - Page, - RankHighest, - RankHistory, - Statistics, - User, - UserAchievement, -) - def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" return (timestamp + 62135596800) * 10_000_000 -async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User: - """将数据库用户模型转换为API用户模型(使用 Lazer 表)""" - - # 从db_user获取基本字段值 - user_id = getattr(db_user, "id") - user_name = getattr(db_user, "name") - user_country = getattr(db_user, "country") - user_country_code = user_country # 在User模型中,country字段就是country_code - - # 获取 Lazer 用户资料 - profile = db_user.lazer_profile - if not profile: - # 如果没有 lazer 资料,使用默认值 - profile = LazerUserProfile( - user_id=user_id, - ) - - # 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系 - lzrcnt = db_user.lazer_counts - - if not lzrcnt: - # 如果没有 lazer 计数,使用默认值 - lzrcnt = LazerUserCounts(user_id=user_id) - - # 获取指定模式的统计信息 - user_stats = None - if db_user.lazer_statistics: - for stat in db_user.lazer_statistics: - if stat.mode == ruleset: - user_stats = stat - break - - if not user_stats: - # 如果没有找到指定模式的统计,创建默认统计 - user_stats = LazerUserStatistics(user_id=user_id) - - # 获取国家信息 - country_code = db_user.country_code if db_user.country_code is not None else "XX" - - country = Country(code=str(country_code), name=get_country_name(str(country_code))) - - # 获取 Kudosu 信息 - kudosu = Kudosu(available=0, total=0) - - # 获取计数信息 - # counts = LazerUserCounts(user_id=user_id) - - # 转换统计信息 - statistics = Statistics( - count_100=user_stats.count_100, - count_300=user_stats.count_300, - count_50=user_stats.count_50, - count_miss=user_stats.count_miss, - level=Level( - current=user_stats.level_current, progress=user_stats.level_progress - ), - global_rank=user_stats.global_rank, - global_rank_exp=user_stats.global_rank_exp, - pp=float(user_stats.pp) if user_stats.pp else 0.0, - pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0, - ranked_score=user_stats.ranked_score, - hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0, - play_count=user_stats.play_count, - play_time=user_stats.play_time, - total_score=user_stats.total_score, - total_hits=user_stats.total_hits, - maximum_combo=user_stats.maximum_combo, - replays_watched_by_others=user_stats.replays_watched_by_others, - is_ranked=user_stats.is_ranked, - grade_counts=GradeCounts( - ss=user_stats.grade_ss, - ssh=user_stats.grade_ssh, - s=user_stats.grade_s, - sh=user_stats.grade_sh, - a=user_stats.grade_a, - ), - country_rank=user_stats.country_rank, - rank={"country": user_stats.country_rank} if user_stats.country_rank else None, - ) - - # 转换所有模式的统计信息 - statistics_rulesets = {} - if db_user.lazer_statistics: - for stat in db_user.lazer_statistics: - statistics_rulesets[stat.mode] = Statistics( - count_100=stat.count_100, - count_300=stat.count_300, - count_50=stat.count_50, - count_miss=stat.count_miss, - level=Level(current=stat.level_current, progress=stat.level_progress), - global_rank=stat.global_rank, - global_rank_exp=stat.global_rank_exp, - pp=float(stat.pp) if stat.pp else 0.0, - pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0, - ranked_score=stat.ranked_score, - hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0, - play_count=stat.play_count, - play_time=stat.play_time, - total_score=stat.total_score, - total_hits=stat.total_hits, - maximum_combo=stat.maximum_combo, - replays_watched_by_others=stat.replays_watched_by_others, - is_ranked=stat.is_ranked, - grade_counts=GradeCounts( - ss=stat.grade_ss, - ssh=stat.grade_ssh, - s=stat.grade_s, - sh=stat.grade_sh, - a=stat.grade_a, - ), - country_rank=stat.country_rank, - rank={"country": stat.country_rank} if stat.country_rank else None, - ) - - # 转换国家信息 - country = Country(code=user_country_code, name=get_country_name(user_country_code)) - - # 转换封面信息 - cover_url = ( - profile.cover_url - if profile and profile.cover_url - else "https://assets.ppy.sh/user-profile-covers/default.jpeg" - ) - cover = Cover( - custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None - ) - - # 转换 Kudosu 信息 - kudosu = Kudosu(available=0, total=0) - - # 转换成就信息 - user_achievements = [] - if db_user.lazer_achievements: - for achievement in db_user.lazer_achievements: - user_achievements.append( - UserAchievement( - achieved_at=achievement.achieved_at, - achievement_id=achievement.achievement_id, - ) - ) - - # 转换排名历史 - rank_history = None - rank_history_data = None - for rh in db_user.rank_history: - if rh.mode == ruleset: - rank_history_data = rh.rank_data - break - - if rank_history_data: - rank_history = RankHistory(mode=ruleset, data=rank_history_data) - - # 转换每日挑战统计 - # daily_challenge_stats = None - # if db_user.daily_challenge_stats: - # dcs = db_user.daily_challenge_stats - # daily_challenge_stats = DailyChallengeStats( - # daily_streak_best=dcs.daily_streak_best, - # daily_streak_current=dcs.daily_streak_current, - # last_update=dcs.last_update, - # last_weekly_streak=dcs.last_weekly_streak, - # playcount=dcs.playcount, - # top_10p_placements=dcs.top_10p_placements, - # top_50p_placements=dcs.top_50p_placements, - # user_id=dcs.user_id, - # weekly_streak_best=dcs.weekly_streak_best, - # weekly_streak_current=dcs.weekly_streak_current, - # ) - - # 转换最高排名 - rank_highest = None - if user_stats.rank_highest: - rank_highest = RankHighest( - rank=user_stats.rank_highest, - updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(), - ) - - # 转换团队信息 - team = None - if db_user.team_membership: - team_member = db_user.team_membership # 假设用户只属于一个团队 - team = team_member.team - - # 创建用户对象 - # 从db_user获取基本字段值 - user_id = getattr(db_user, "id") - user_name = getattr(db_user, "name") - user_country = getattr(db_user, "country") - - # 获取用户头像URL - avatar_url = None - - # 首先检查 profile 中的 avatar_url - if profile and hasattr(profile, "avatar_url") and profile.avatar_url: - avatar_url = str(profile.avatar_url) - - # 然后检查是否有关联的头像记录 - if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None: - if db_user.avatar.r2_game_url: - # 优先使用游戏用的头像URL - avatar_url = str(db_user.avatar.r2_game_url) - elif db_user.avatar.r2_original_url: - # 其次使用原始头像URL - avatar_url = str(db_user.avatar.r2_original_url) - - # 如果还是没有找到,通过查询获取 - # if db_session and avatar_url is None: - # try: - # # 导入UserAvatar模型 - - # # 尝试查找用户的头像记录 - # statement = select(UserAvatar).where( - # UserAvatar.user_id == user_id, UserAvatar.is_active == True - # ) - # avatar_record = db_session.exec(statement).first() - # if avatar_record is not None: - # if avatar_record.r2_game_url is not None: - # # 优先使用游戏用的头像URL - # avatar_url = str(avatar_record.r2_game_url) - # elif avatar_record.r2_original_url is not None: - # # 其次使用原始头像URL - # avatar_url = str(avatar_record.r2_original_url) - # except Exception as e: - # print(f"获取用户头像时出错: {e}") - # print(f"最终头像URL: {avatar_url}") - # 如果仍然没有找到头像URL,则使用默认URL - if avatar_url is None: - avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1" - - # 处理 profile_order 列表排序 - profile_order = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - if profile and profile.profile_order: - profile_order = profile.profile_order.split(",") - - # 在convert_db_user_to_api_user函数中添加active_tournament_banners处理 - active_tournament_banners = [] - if db_user.active_banners: - for banner in db_user.active_banners: - active_tournament_banners.append( - { - "tournament_id": banner.tournament_id, - "image_url": banner.image_url, - "is_active": banner.is_active, - } - ) - - # 在convert_db_user_to_api_user函数中添加badges处理 - badges = [] - if db_user.lazer_badges: - for badge in db_user.lazer_badges: - badges.append( - { - "badge_id": badge.badge_id, - "awarded_at": badge.awarded_at, - "description": badge.description, - "image_url": badge.image_url, - "url": badge.url, - } - ) - - # 在convert_db_user_to_api_user函数中添加monthly_playcounts处理 - monthly_playcounts = [] - if db_user.lazer_monthly_playcounts: - for playcount in db_user.lazer_monthly_playcounts: - monthly_playcounts.append( - { - "start_date": playcount.start_date.isoformat() - if playcount.start_date - else None, - "play_count": playcount.play_count, - } - ) - - # 在convert_db_user_to_api_user函数中添加previous_usernames处理 - previous_usernames = [] - if db_user.lazer_previous_usernames: - for username in db_user.lazer_previous_usernames: - previous_usernames.append( - { - "username": username.username, - "changed_at": username.changed_at.isoformat() - if username.changed_at - else None, - } - ) - - # 在convert_db_user_to_api_user函数中添加replays_watched_counts处理 - replays_watched_counts = [] - if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched: - for replay in db_user.lazer_replays_watched: - replays_watched_counts.append( - { - "start_date": replay.start_date.isoformat() - if replay.start_date - else None, - "count": replay.count, - } - ) - - # 创建用户对象 - user = User( - id=user_id, - username=user_name, - avatar_url=avatar_url, - country_code=str(country_code), - default_group=profile.default_group if profile else "default", - is_active=profile.is_active, - is_bot=profile.is_bot, - is_deleted=profile.is_deleted, - is_online=profile.is_online, - is_supporter=profile.is_supporter, - is_restricted=profile.is_restricted, - last_visit=db_user.last_visit, - pm_friends_only=profile.pm_friends_only, - profile_colour=profile.profile_colour, - cover_url=profile.cover_url - if profile and profile.cover_url - else "https://assets.ppy.sh/user-profile-covers/default.jpeg", - discord=profile.discord if profile else None, - has_supported=profile.has_supported if profile else False, - interests=profile.interests if profile else None, - join_date=profile.join_date if profile.join_date else datetime.now(UTC), - location=profile.location if profile else None, - max_blocks=profile.max_blocks if profile and profile.max_blocks else 100, - max_friends=profile.max_friends if profile and profile.max_friends else 500, - post_count=profile.post_count if profile and profile.post_count else 0, - profile_hue=profile.profile_hue if profile and profile.profile_hue else None, - profile_order=profile_order, # 使用排序后的 profile_order - title=profile.title if profile else None, - title_url=profile.title_url if profile else None, - twitter=profile.twitter if profile else None, - website=profile.website if profile else None, - session_verified=True, - support_level=profile.support_level if profile else 0, - country=country, - cover=cover, - kudosu=kudosu, - statistics=statistics, - statistics_rulesets=statistics_rulesets, - beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0, - comments_count=lzrcnt.comments_count if lzrcnt else 0, - favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0, - follower_count=lzrcnt.follower_count if lzrcnt else 0, - graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0, - guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0, - loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0, - mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0, - nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0, - pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0, - ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0, - ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count - if lzrcnt - else 0, - unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0, - scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0, - scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0, - scores_pinned_count=lzrcnt.scores_pinned_count, - scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0, - account_history=[], # TODO: 获取用户历史账户信息 - # active_tournament_banner=len(active_tournament_banners), - active_tournament_banners=active_tournament_banners, - badges=badges, - current_season_stats=None, - daily_challenge_user_stats=DailyChallengeStats( - user_id=user_id, - daily_streak_best=db_user.daily_challenge_stats.daily_streak_best - if db_user.daily_challenge_stats - else 0, - daily_streak_current=db_user.daily_challenge_stats.daily_streak_current - if db_user.daily_challenge_stats - else 0, - last_update=db_user.daily_challenge_stats.last_update - if db_user.daily_challenge_stats - else None, - last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak - if db_user.daily_challenge_stats - else None, - playcount=db_user.daily_challenge_stats.playcount - if db_user.daily_challenge_stats - else 0, - top_10p_placements=db_user.daily_challenge_stats.top_10p_placements - if db_user.daily_challenge_stats - else 0, - top_50p_placements=db_user.daily_challenge_stats.top_50p_placements - if db_user.daily_challenge_stats - else 0, - weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best - if db_user.daily_challenge_stats - else 0, - weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current - if db_user.daily_challenge_stats - else 0, - ), - groups=[], - monthly_playcounts=monthly_playcounts, - page=Page(html=profile.page_html or "", raw=profile.page_raw or "") - if profile.page_html or profile.page_raw - else Page(), - previous_usernames=previous_usernames, - rank_highest=rank_highest, - rank_history=rank_history, - rankHistory=rank_history, - replays_watched_counts=replays_watched_counts, - team=team, - user_achievements=user_achievements, - ) - - return user +def camel_to_snake(name: str) -> str: + """Convert a camelCase string to snake_case.""" + result = [] + last_chr = "" + for char in name: + if char.isupper(): + if not last_chr.isupper() and result: + result.append("_") + result.append(char.lower()) + else: + result.append(char) + last_chr = char + return "".join(result) -def get_country_name(country_code: str) -> str: - """根据国家代码获取国家名称""" - country_names = { - "CN": "China", - "JP": "Japan", - "US": "United States", - "GB": "United Kingdom", - "DE": "Germany", - "FR": "France", - "KR": "South Korea", - "CA": "Canada", - "AU": "Australia", - "BR": "Brazil", - # 可以添加更多国家 +def snake_to_camel(name: str, use_abbr: bool = True) -> str: + """Convert a snake_case string to camelCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", } - return country_names.get(country_code, "Unknown") + + result = [] + for part in parts: + if part.lower() in abbreviations and use_abbr: + result.append(part.upper()) + else: + if result: + result.append(part.capitalize()) + else: + result.append(part.lower()) + + return "".join(result) + + +def snake_to_pascal(name: str, use_abbr: bool = True) -> str: + """Convert a snake_case string to PascalCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", + } + + result = [] + for part in parts: + if part.lower() in abbreviations and use_abbr: + result.append(part.upper()) + else: + result.append(part.capitalize()) + + return "".join(result) diff --git a/main.py b/main.py index 526d593..b12f543 100644 --- a/main.py +++ b/main.py @@ -4,25 +4,27 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.database import Team # noqa: F401 -from app.dependencies.database import create_tables, engine +from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher -from app.models.user import User -from app.router import api_router, auth_router, fetcher_router, signalr_router +from app.router import ( + api_router, + auth_router, + fetcher_router, + signalr_router, +) from fastapi import FastAPI -User.model_rebuild() - @asynccontextmanager async def lifespan(app: FastAPI): # on startup await create_tables() - get_fetcher() # 初始化 fetcher + await get_fetcher() # 初始化 fetcher # on shutdown yield await engine.dispose() + await redis_client.aclose() app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) @@ -44,104 +46,6 @@ async def health_check(): return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} -# @app.get("/api/v2/friends") -# async def get_friends(): -# return JSONResponse( -# content=[ -# { -# "id": 123456, -# "username": "BestFriend", -# "is_online": True, -# "is_supporter": False, -# "country": {"code": "US", "name": "United States"}, -# } -# ] -# ) - - -# @app.get("/api/v2/notifications") -# async def get_notifications(): -# return JSONResponse(content={"notifications": [], "unread_count": 0}) - - -# @app.post("/api/v2/chat/ack") -# async def chat_ack(): -# return JSONResponse(content={"status": "ok"}) - - -# @app.get("/api/v2/users/{user_id}/{mode}") -# async def get_user_mode(user_id: int, mode: str): -# return JSONResponse( -# content={ -# "id": user_id, -# "username": "测试测试测", -# "statistics": { -# "level": {"current": 97, "progress": 96}, -# "pp": 114514, -# "global_rank": 666, -# "country_rank": 1, -# "hit_accuracy": 100, -# }, -# "country": {"code": "JP", "name": "Japan"}, -# } -# ) - - -# @app.get("/api/v2/me") -# async def get_me(): -# return JSONResponse( -# content={ -# "id": 15651670, -# "username": "Googujiang", -# "is_online": True, -# "country": {"code": "JP", "name": "Japan"}, -# "statistics": { -# "level": {"current": 97, "progress": 96}, -# "pp": 2826.26, -# "global_rank": 298026, -# "country_rank": 11220, -# "hit_accuracy": 95.7168, -# }, -# } -# ) - - -# @app.post("/signalr/metadata/negotiate") -# async def metadata_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "abc123", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - -# @app.post("/signalr/spectator/negotiate") -# async def spectator_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "spec456", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - -# @app.post("/signalr/multiplayer/negotiate") -# async def multiplayer_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "multi789", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - if __name__ == "__main__": from app.log import logger # noqa: F401 diff --git a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py similarity index 62% rename from migrations/versions/78be13c71791_score_remove_best_id_in_database.py rename to migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py index d0cab2b..84bae15 100644 --- a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py +++ b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py @@ -1,8 +1,8 @@ -"""score: remove best_id in database +"""beatmapset: support favourite count -Revision ID: 78be13c71791 -Revises: dc4d25c428c7 -Create Date: 2025-07-29 07:57:33.764517 +Revision ID: 1178d0758ebf +Revises: +Create Date: 2025-08-01 04:05:09.882800 """ @@ -15,8 +15,8 @@ import sqlalchemy as sa from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision: str = "78be13c71791" -down_revision: str | Sequence[str] | None = "dc4d25c428c7" +revision: str = "1178d0758ebf" +down_revision: str | Sequence[str] | None = None branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("scores", "best_id") + op.drop_column("beatmapsets", "favourite_count") # ### end Alembic commands ### @@ -32,7 +32,9 @@ def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### op.add_column( - "scores", - sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True), + "beatmapsets", + sa.Column( + "favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False + ), ) # ### end Alembic commands ### diff --git a/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py new file mode 100644 index 0000000..e383621 --- /dev/null +++ b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py @@ -0,0 +1,54 @@ +"""relationship: fix unique relationship + +Revision ID: 58a11441d302 +Revises: 1178d0758ebf +Create Date: 2025-08-01 04:23:02.498166 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "58a11441d302" +down_revision: str | Sequence[str] | None = "1178d0758ebf" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "relationship", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + ) + op.drop_constraint("PRIMARY", "relationship", type_="primary") + op.create_primary_key("pk_relationship", "relationship", ["id"]) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True + ) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("pk_relationship", "relationship", type_="primary") + op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"]) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.drop_column("relationship", "id") + # ### end Alembic commands ### diff --git a/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py new file mode 100644 index 0000000..74f2e56 --- /dev/null +++ b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py @@ -0,0 +1,89 @@ +"""playlist: index playlist id + +Revision ID: d0c1b2cefe91 +Revises: 58a11441d302 +Create Date: 2025-08-06 06:02:10.512616 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "d0c1b2cefe91" +down_revision: str | Sequence[str] | None = "58a11441d302" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_index( + op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False + ) + op.create_table( + "playlist_best_scores", + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("score_id", sa.BigInteger(), nullable=False), + sa.Column("room_id", sa.Integer(), nullable=False), + sa.Column("playlist_id", sa.Integer(), nullable=False), + sa.Column("total_score", sa.BigInteger(), nullable=True), + sa.ForeignKeyConstraint( + ["playlist_id"], + ["room_playlists.id"], + ), + sa.ForeignKeyConstraint( + ["room_id"], + ["rooms.id"], + ), + sa.ForeignKeyConstraint( + ["score_id"], + ["scores.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("score_id"), + ) + op.create_index( + op.f("ix_playlist_best_scores_playlist_id"), + "playlist_best_scores", + ["playlist_id"], + unique=False, + ) + op.create_index( + op.f("ix_playlist_best_scores_room_id"), + "playlist_best_scores", + ["room_id"], + unique=False, + ) + op.create_index( + op.f("ix_playlist_best_scores_user_id"), + "playlist_best_scores", + ["user_id"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores" + ) + op.drop_index( + op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores" + ) + op.drop_index( + op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores" + ) + op.drop_table("playlist_best_scores") + op.drop_index(op.f("ix_room_playlists_id"), table_name="room_playlists") + # ### end Alembic commands ### diff --git a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py b/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py deleted file mode 100644 index d90ec3d..0000000 --- a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py +++ /dev/null @@ -1,36 +0,0 @@ -"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator - -Revision ID: dc4d25c428c7 -Revises: -Create Date: 2025-07-29 01:43:40.221070 - -""" - -from __future__ import annotations - -from collections.abc import Sequence - -from alembic import op -import sqlalchemy as sa - -# revision identifiers, used by Alembic. -revision: str = "dc4d25c428c7" -down_revision: str | Sequence[str] | None = None -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True)) - op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("scores", "nsmall_tick_hit") - op.drop_column("scores", "nlarge_tick_hit") - # ### end Alembic commands ### diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index 88b79c5..433c53b 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -1,11 +1,4 @@ from typing import Any -class APIMod: - def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ... - @property - def acronym(self) -> str: ... - @property - def settings(self) -> str: ... - def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index 15156ca..1e36c42 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -1,8 +1,6 @@ -use crate::APIMod; use chrono::{TimeZone, Utc}; use pyo3::types::PyDict; use pyo3::{prelude::*, IntoPyObjectExt}; -use std::collections::HashMap; use std::io::Read; pub fn read_object( @@ -13,6 +11,8 @@ pub fn read_object( match rmp::decode::read_marker(cursor) { Ok(marker) => match marker { rmp::Marker::Null => Ok(py.None()), + rmp::Marker::True => Ok(true.into_py_any(py)?), + rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::U8 => { @@ -86,8 +86,6 @@ pub fn read_object( cursor.read_exact(&mut data).map_err(to_py_err)?; Ok(data.into_pyobject(py)?.into_any().unbind()) } - rmp::Marker::True => Ok(true.into_py_any(py)?), - rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32), rmp::Marker::Str8 => { let mut buf = [0u8; 1]; @@ -206,13 +204,12 @@ fn read_array( let obj1 = read_object(py, cursor, false)?; if obj1.extract::(py).map_or(false, |k| k.len() == 2) { let obj2 = read_object(py, cursor, true)?; - return Ok(APIMod { - acronym: obj1.extract::(py)?, - settings: obj2.extract::>(py)?, - } - .into_pyobject(py)? - .into_any() - .unbind()); + + let api_mod_dict = PyDict::new(py); + api_mod_dict.set_item("acronym", obj1)?; + api_mod_dict.set_item("settings", obj2)?; + + return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind()); } else { items.push(obj1); i += 1; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 88a732b..3ff4864 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -1,8 +1,7 @@ -use crate::APIMod; -use chrono::{DateTime, Utc}; -use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods}; +use chrono::{DateTime, Utc}; +use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods}; use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString}; -use pyo3::{Bound, PyAny, PyRef, Python}; +use pyo3::{Bound, PyAny}; use std::io::Write; fn write_list(buf: &mut Vec, obj: &Bound<'_, PyList>) { @@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec, obj: &Bound<'_, PyDict>) { } } -fn write_nil(buf: &mut Vec){ +fn write_nil(buf: &mut Vec) { rmp::encode::write_nil(buf).unwrap(); } -// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs -fn write_api_mod(buf: &mut Vec, api_mod: PyRef) { - rmp::encode::write_array_len(buf, 2).unwrap(); - rmp::encode::write_str(buf, &api_mod.acronym).unwrap(); - rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap(); - for (k, v) in api_mod.settings.iter() { - rmp::encode::write_str(buf, k).unwrap(); - Python::with_gil(|py| write_object(buf, &v.bind(py))); +fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool { + if let Ok(Some(acronym)) = dict.get_item("acronym") { + if let Ok(acronym_str) = acronym.extract::() { + return acronym_str.len() == 2; + } } + false +} + +// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs +fn write_api_mod(buf: &mut Vec, api_mod: &Bound<'_, PyDict>) -> PyResult<()> { + let acronym = api_mod + .get_item("acronym")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?; + let acronym_str = acronym.extract::()?; + + let settings = api_mod + .get_item("settings")? + .unwrap_or_else(|| PyDict::new(acronym.py()).into_any()); + let settings_dict = settings.downcast::()?; + + rmp::encode::write_array_len(buf, 2).unwrap(); + rmp::encode::write_str(buf, &acronym_str).unwrap(); + rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap(); + + for (k, v) in settings_dict.iter() { + let key_str = k.extract::()?; + rmp::encode::write_str(buf, &key_str).unwrap(); + write_object(buf, &v); + } + + Ok(()) } fn write_datetime(buf: &mut Vec, obj: &Bound<'_, PyDateTime>) { @@ -110,22 +132,24 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { write_list(buf, list); } else if let Ok(string) = obj.downcast::() { write_string(buf, string); - } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); - } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); } else if let Ok(boolean) = obj.downcast::() { write_bool(buf, boolean); + } else if let Ok(float) = obj.downcast::() { + write_float(buf, float); + } else if let Ok(integer) = obj.downcast::() { + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() { - write_hashmap(buf, dict); + if is_api_mod(dict) { + write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict)); + } else { + write_hashmap(buf, dict); + } } else if let Ok(_none) = obj.downcast::() { write_nil(buf); } else if let Ok(datetime) = obj.downcast::() { write_datetime(buf, datetime); - } else if let Ok(api_mod) = obj.extract::>() { - write_api_mod(buf, api_mod); } else { panic!("Unsupported type"); } diff --git a/packages/msgpack_lazer_api/src/lib.rs b/packages/msgpack_lazer_api/src/lib.rs index fda540c..220e645 100644 --- a/packages/msgpack_lazer_api/src/lib.rs +++ b/packages/msgpack_lazer_api/src/lib.rs @@ -2,30 +2,6 @@ mod decode; mod encode; use pyo3::prelude::*; -use std::collections::HashMap; - -#[pyclass] -struct APIMod { - #[pyo3(get, set)] - acronym: String, - #[pyo3(get, set)] - settings: HashMap, -} - -#[pymethods] -impl APIMod { - #[new] - fn new(acronym: String, settings: HashMap) -> Self { - APIMod { acronym, settings } - } - - fn __repr__(&self) -> String { - format!( - "APIMod(acronym='{}', settings={:?})", - self.acronym, self.settings - ) - } -} #[pyfunction] #[pyo3(name = "encode")] @@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult { fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(encode_py, m)?)?; m.add_function(wrap_pyfunction!(decode_py, m)?)?; - m.add_class::()?; Ok(()) } diff --git a/static/README.md b/static/README.md index 16ece63..77b54fe 100644 --- a/static/README.md +++ b/static/README.md @@ -2,4 +2,4 @@ - `mods.json`: 包含了游戏中的所有可用mod的详细信息。 - Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json - - Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380` + - Version: 2025/7/30 `ff49b66b27a2850aea4b6b3ba563cfe936cb6082` diff --git a/static/mods.json b/static/mods.json index defb57f..0a8449b 100644 --- a/static/mods.json +++ b/static/mods.json @@ -2438,7 +2438,8 @@ "Settings": [], "IncompatibleMods": [ "CN", - "RX" + "RX", + "MF" ], "RequiresConfiguration": false, "UserPlayable": false, @@ -2460,7 +2461,8 @@ "AC", "AT", "CN", - "RX" + "RX", + "MF" ], "RequiresConfiguration": false, "UserPlayable": false, @@ -2477,7 +2479,8 @@ "Settings": [], "IncompatibleMods": [ "AT", - "CN" + "CN", + "MF" ], "RequiresConfiguration": false, "UserPlayable": true, @@ -2638,6 +2641,24 @@ "ValidForMultiplayerAsFreeMod": true, "AlwaysValidForSubmission": false }, + { + "Acronym": "MF", + "Name": "Moving Fast", + "Description": "Dashing by default, slow down!", + "Type": "Fun", + "Settings": [], + "IncompatibleMods": [ + "AT", + "CN", + "RX" + ], + "RequiresConfiguration": false, + "UserPlayable": true, + "ValidForMultiplayer": true, + "ValidForFreestyleAsRequiredMod": false, + "ValidForMultiplayerAsFreeMod": true, + "AlwaysValidForSubmission": false + }, { "Acronym": "SV2", "Name": "Score V2",