diff --git a/.gitignore b/.gitignore
index 369e759..05622b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,6 +37,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
+test-cert/
htmlcov/
.tox/
.nox/
diff --git a/app/auth.py b/app/auth.py
index 4c690f8..4762662 100644
--- a/app/auth.py
+++ b/app/auth.py
@@ -8,7 +8,7 @@ import string
from app.config import settings
from app.database import (
OAuthToken,
- User as DBUser,
+ User,
)
from app.log import logger
@@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str:
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
-) -> DBUser | None:
+) -> User | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -82,7 +82,7 @@ async def authenticate_user_legacy(
pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户
- statement = select(DBUser).where(DBUser.name == name)
+ statement = select(User).where(User.username == name)
user = (await db.exec(statement)).first()
if not user:
return None
@@ -113,7 +113,7 @@ async def authenticate_user_legacy(
async def authenticate_user(
db: AsyncSession, username: str, password: str
-) -> DBUser | None:
+) -> User | None:
"""验证用户身份"""
return await authenticate_user_legacy(db, username, password)
diff --git a/app/database/__init__.py b/app/database/__init__.py
index 191a193..0ee253b 100644
--- a/app/database/__init__.py
+++ b/app/database/__init__.py
@@ -1,3 +1,4 @@
+from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthToken
from .beatmap import (
Beatmap as Beatmap,
@@ -8,65 +9,64 @@ from .beatmapset import (
BeatmapsetResp as BeatmapsetResp,
)
from .best_score import BestScore
-from .legacy import LegacyOAuthToken, LegacyUserStatistics
+from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
+from .favourite_beatmapset import FavouriteBeatmapset
+from .lazer_user import (
+ User,
+ UserResp,
+)
+from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
+from .playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
+from .playlist_best_score import PlaylistBestScore
+from .playlists import Playlist, PlaylistResp
+from .pp_best_score import PPBestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
+from .room import Room, RoomResp
from .score import (
+ MultiplayerScores,
Score,
+ ScoreAround,
ScoreBase,
ScoreResp,
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
+from .statistics import (
+ UserStatistics,
+ UserStatisticsResp,
+)
from .team import Team, TeamMember
-from .user import (
- DailyChallengeStats,
- LazerUserAchievement,
- LazerUserBadge,
- LazerUserBanners,
- LazerUserCountry,
- LazerUserCounts,
- LazerUserKudosu,
- LazerUserMonthlyPlaycounts,
- LazerUserPreviousUsername,
- LazerUserProfile,
- LazerUserProfileSections,
- LazerUserReplaysWatched,
- LazerUserStatistics,
- RankHistory,
- User,
- UserAchievement,
- UserAvatar,
+from .user_account_history import (
+ UserAccountHistory,
+ UserAccountHistoryResp,
+ UserAccountHistoryType,
)
-BeatmapsetResp.model_rebuild()
-BeatmapResp.model_rebuild()
__all__ = [
"Beatmap",
- "BeatmapResp",
"Beatmapset",
"BeatmapsetResp",
"BestScore",
"DailyChallengeStats",
- "LazerUserAchievement",
- "LazerUserBadge",
- "LazerUserBanners",
- "LazerUserCountry",
- "LazerUserCounts",
- "LazerUserKudosu",
- "LazerUserMonthlyPlaycounts",
- "LazerUserPreviousUsername",
- "LazerUserProfile",
- "LazerUserProfileSections",
- "LazerUserReplaysWatched",
- "LazerUserStatistics",
- "LegacyOAuthToken",
- "LegacyUserStatistics",
+ "DailyChallengeStatsResp",
+ "FavouriteBeatmapset",
+ "ItemAttemptsCount",
+ "ItemAttemptsResp",
+ "MultiplayerEvent",
+ "MultiplayerEventResp",
+ "MultiplayerScores",
"OAuthToken",
- "RankHistory",
+ "PPBestScore",
+ "Playlist",
+ "PlaylistBestScore",
+ "PlaylistResp",
"Relationship",
"RelationshipResp",
"RelationshipType",
+ "Room",
+ "RoomResp",
"Score",
+ "ScoreAround",
"ScoreBase",
"ScoreResp",
"ScoreStatistics",
@@ -75,6 +75,17 @@ __all__ = [
"Team",
"TeamMember",
"User",
+ "UserAccountHistory",
+ "UserAccountHistoryResp",
+ "UserAccountHistoryType",
"UserAchievement",
- "UserAvatar",
+ "UserAchievement",
+ "UserAchievementResp",
+ "UserResp",
+ "UserStatistics",
+ "UserStatisticsResp",
]
+
+for i in __all__:
+ if i.endswith("Resp"):
+ globals()[i].model_rebuild() # type: ignore[call-arg]
diff --git a/app/database/achievement.py b/app/database/achievement.py
new file mode 100644
index 0000000..4be587f
--- /dev/null
+++ b/app/database/achievement.py
@@ -0,0 +1,40 @@
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING
+
+from app.models.model import UTCBaseModel
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
+
+if TYPE_CHECKING:
+ from .lazer_user import User
+
+
+class UserAchievementBase(SQLModel, UTCBaseModel):
+ achievement_id: int = Field(primary_key=True)
+ achieved_at: datetime = Field(
+ default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
+ )
+
+
+class UserAchievement(UserAchievementBase, table=True):
+ __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
+
+ id: int | None = Field(default=None, primary_key=True, index=True)
+ user_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
+ )
+ user: "User" = Relationship(back_populates="achievement")
+
+
+class UserAchievementResp(UserAchievementBase):
+ @classmethod
+ def from_db(cls, db_model: UserAchievement) -> "UserAchievementResp":
+ return cls.model_validate(db_model)
diff --git a/app/database/auth.py b/app/database/auth.py
index ae49676..554dced 100644
--- a/app/database/auth.py
+++ b/app/database/auth.py
@@ -1,19 +1,21 @@
from datetime import datetime
from typing import TYPE_CHECKING
+from app.models.model import UTCBaseModel
+
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
- from .user import User
+ from .lazer_user import User
-class OAuthToken(SQLModel, table=True):
+class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
- sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)
diff --git a/app/database/beatmap.py b/app/database/beatmap.py
index b28473e..62a0120 100644
--- a/app/database/beatmap.py
+++ b/app/database/beatmap.py
@@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime
-from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
+ from .lazer_user import User
+
class BeatmapOwner(SQLModel):
id: int
@@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus
# optional
- beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
+ beatmapset: Beatmapset = Relationship(
+ back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
+ )
@property
def can_ranked(self) -> bool:
@@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap)
await session.commit()
beatmap = (
- await session.exec(
- select(Beatmap)
- .options(
- joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
- )
- .where(Beatmap.id == resp.id)
- )
+ await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap":
beatmap = (
await session.exec(
- select(Beatmap)
- .where(
+ select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
- .options(
- joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
- )
)
).first()
if not beatmap:
@@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase):
url: str = ""
@classmethod
- def from_db(
+ async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
from_set: bool = False,
+ session: AsyncSession | None = None,
+ user: "User | None" = None,
) -> "BeatmapResp":
beatmap_ = beatmap.model_dump()
if query_mode is not None and beatmap.mode != query_mode:
@@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set:
- beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
+ beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
+ beatmap.beatmapset, session=session, user=user
+ )
return cls.model_validate(beatmap_)
diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py
index 1e6ba27..3bad7e9 100644
--- a/app/database/beatmapset.py
+++ b/app/database/beatmapset.py
@@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
+from .lazer_user import BASE_INCLUDES, User, UserResp
+
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
-from sqlmodel import Field, Relationship, SQLModel
+from sqlalchemy.ext.asyncio import AsyncAttrs
+from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp
+ from .favourite_beatmapset import FavouriteBeatmapset
class BeatmapCovers(SQLModel):
@@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel):
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
creator: str
- favourite_count: int
nsfw: bool = Field(default=False)
play_count: int
preview_url: str
@@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel):
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
ratings: list[int] = Field(default=None, sa_column=Column(JSON))
- # TODO: recent_favourites: Optional[list[User]] = None
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None) # feature artist?
- # TODO: has_favourited
# BeatmapsetExtended
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
@@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel):
tags: str = Field(default="", sa_column=Column(Text))
-class Beatmapset(BeatmapsetBase, table=True):
+class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True):
hype_required: int = Field(default=0)
availability_info: str | None = Field(default=None)
download_disabled: bool = Field(default=False)
+ favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp(
@@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase):
genre: BeatmapTranslationText | None = None
language: BeatmapTranslationText | None = None
nominations: BeatmapNominations | None = None
+ has_favourited: bool = False
+ favourite_count: int = 0
+ recent_favourites: list[UserResp] = Field(default_factory=list)
@classmethod
- def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
+ async def from_db(
+ cls,
+ beatmapset: Beatmapset,
+ include: list[str] = [],
+ session: AsyncSession | None = None,
+ user: User | None = None,
+ ) -> "BeatmapsetResp":
from .beatmap import BeatmapResp
+ from .favourite_beatmapset import FavouriteBeatmapset
- beatmaps = [
- BeatmapResp.from_db(beatmap, from_set=True)
- for beatmap in beatmapset.beatmaps
- ]
+ update = {
+ "beatmaps": [
+ await BeatmapResp.from_db(beatmap, from_set=True)
+ for beatmap in await beatmapset.awaitable_attrs.beatmaps
+ ],
+ "hype": BeatmapHype(
+ current=beatmapset.hype_current, required=beatmapset.hype_required
+ ),
+ "availability": BeatmapAvailability(
+ more_information=beatmapset.availability_info,
+ download_disabled=beatmapset.download_disabled,
+ ),
+ "genre": BeatmapTranslationText(
+ name=beatmapset.beatmap_genre.name,
+ id=beatmapset.beatmap_genre.value,
+ ),
+ "language": BeatmapTranslationText(
+ name=beatmapset.beatmap_language.name,
+ id=beatmapset.beatmap_language.value,
+ ),
+ "nominations": BeatmapNominations(
+ required=beatmapset.nominations_required,
+ current=beatmapset.nominations_current,
+ ),
+ "status": beatmapset.beatmap_status.name.lower(),
+ "ranked": beatmapset.beatmap_status.value,
+ "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
+ **beatmapset.model_dump(),
+ }
+ if session and user:
+ existing_favourite = (
+ await session.exec(
+ select(FavouriteBeatmapset).where(
+ FavouriteBeatmapset.beatmapset_id == beatmapset.id
+ )
+ )
+ ).first()
+ update["has_favourited"] = existing_favourite is not None
+
+ if session and "recent_favourites" in include:
+ recent_favourites = (
+ await session.exec(
+ select(FavouriteBeatmapset)
+ .where(
+ FavouriteBeatmapset.beatmapset_id == beatmapset.id,
+ )
+ .order_by(col(FavouriteBeatmapset.date).desc())
+ .limit(50)
+ )
+ ).all()
+ update["recent_favourites"] = [
+ await UserResp.from_db(
+ await favourite.awaitable_attrs.user,
+ session=session,
+ include=BASE_INCLUDES,
+ )
+ for favourite in recent_favourites
+ ]
+
+ if session:
+ update["favourite_count"] = (
+ await session.exec(
+ select(func.count())
+ .select_from(FavouriteBeatmapset)
+ .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
+ )
+ ).one()
return cls.model_validate(
- {
- "beatmaps": beatmaps,
- "hype": BeatmapHype(
- current=beatmapset.hype_current, required=beatmapset.hype_required
- ),
- "availability": BeatmapAvailability(
- more_information=beatmapset.availability_info,
- download_disabled=beatmapset.download_disabled,
- ),
- "genre": BeatmapTranslationText(
- name=beatmapset.beatmap_genre.name,
- id=beatmapset.beatmap_genre.value,
- ),
- "language": BeatmapTranslationText(
- name=beatmapset.beatmap_language.name,
- id=beatmapset.beatmap_language.value,
- ),
- "nominations": BeatmapNominations(
- required=beatmapset.nominations_required,
- current=beatmapset.nominations_current,
- ),
- "status": beatmapset.beatmap_status.name.lower(),
- "ranked": beatmapset.beatmap_status.value,
- "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
- **beatmapset.model_dump(),
- }
+ update,
)
diff --git a/app/database/best_score.py b/app/database/best_score.py
index 313da3e..8688d5b 100644
--- a/app/database/best_score.py
+++ b/app/database/best_score.py
@@ -1,14 +1,14 @@
from typing import TYPE_CHECKING
-from app.models.score import GameMode
+from app.models.score import GameMode, Rank
-from .user import User
+from .lazer_user import User
from sqlmodel import (
+ JSON,
BigInteger,
Column,
Field,
- Float,
ForeignKey,
Relationship,
SQLModel,
@@ -20,22 +20,27 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True):
- __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
+ __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
- sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
- pp: float = Field(
- sa_column=Column(Float, default=0),
- )
- acc: float = Field(
- sa_column=Column(Float, default=0),
+ total_score: int = Field(default=0, sa_column=Column(BigInteger))
+ mods: list[str] = Field(
+ default_factory=list,
+ sa_column=Column(JSON),
)
+ rank: Rank
user: User = Relationship()
- score: "Score" = Relationship()
+ score: "Score" = Relationship(
+ sa_relationship_kwargs={
+ "foreign_keys": "[BestScore.score_id]",
+ "lazy": "joined",
+ }
+ )
beatmap: "Beatmap" = Relationship()
diff --git a/app/database/daily_challenge.py b/app/database/daily_challenge.py
new file mode 100644
index 0000000..abf874f
--- /dev/null
+++ b/app/database/daily_challenge.py
@@ -0,0 +1,58 @@
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from app.models.model import UTCBaseModel
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
+
+if TYPE_CHECKING:
+ from .lazer_user import User
+
+
+class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
+ daily_streak_best: int = Field(default=0)
+ daily_streak_current: int = Field(default=0)
+ last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
+ last_weekly_streak: datetime | None = Field(
+ default=None, sa_column=Column(DateTime)
+ )
+ playcount: int = Field(default=0)
+ top_10p_placements: int = Field(default=0)
+ top_50p_placements: int = Field(default=0)
+ weekly_streak_best: int = Field(default=0)
+ weekly_streak_current: int = Field(default=0)
+
+
+class DailyChallengeStats(DailyChallengeStatsBase, table=True):
+ __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
+
+ user_id: int | None = Field(
+ default=None,
+ sa_column=Column(
+ BigInteger,
+ ForeignKey("lazer_users.id"),
+ unique=True,
+ index=True,
+ primary_key=True,
+ ),
+ )
+ user: "User" = Relationship(back_populates="daily_challenge_stats")
+
+
+class DailyChallengeStatsResp(DailyChallengeStatsBase):
+ user_id: int
+
+ @classmethod
+ def from_db(
+ cls,
+ obj: DailyChallengeStats,
+ ) -> "DailyChallengeStatsResp":
+ return cls.model_validate(obj)
diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py
new file mode 100644
index 0000000..51bd578
--- /dev/null
+++ b/app/database/favourite_beatmapset.py
@@ -0,0 +1,53 @@
+import datetime
+
+from app.database.beatmapset import Beatmapset
+from app.database.lazer_user import User
+
+from sqlalchemy.ext.asyncio import AsyncAttrs
+from sqlmodel import (
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
+
+
+class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
+ __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
+ id: int | None = Field(
+ default=None,
+ sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
+ exclude=True,
+ )
+ user_id: int = Field(
+ default=None,
+ sa_column=Column(
+ BigInteger,
+ ForeignKey("lazer_users.id"),
+ index=True,
+ ),
+ )
+ beatmapset_id: int = Field(
+ default=None,
+ sa_column=Column(
+ ForeignKey("beatmapsets.id"),
+ index=True,
+ ),
+ )
+ date: datetime.datetime = Field(
+ default=datetime.datetime.now(datetime.UTC),
+ sa_column=Column(
+ DateTime,
+ ),
+ )
+
+ user: User = Relationship(back_populates="favourite_beatmapsets")
+ beatmapset: Beatmapset = Relationship(
+ sa_relationship_kwargs={
+ "lazy": "selectin",
+ },
+ back_populates="favourites",
+ )
diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py
new file mode 100644
index 0000000..2717c3a
--- /dev/null
+++ b/app/database/lazer_user.py
@@ -0,0 +1,333 @@
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING, NotRequired, TypedDict
+
+from app.models.model import UTCBaseModel
+from app.models.score import GameMode
+from app.models.user import Country, Page, RankHistory
+
+from .achievement import UserAchievement, UserAchievementResp
+from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
+from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp
+from .statistics import UserStatistics, UserStatisticsResp
+from .team import Team, TeamMember
+from .user_account_history import UserAccountHistory, UserAccountHistoryResp
+
+from sqlalchemy.ext.asyncio import AsyncAttrs
+from sqlmodel import (
+ JSON,
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ Relationship,
+ SQLModel,
+ func,
+ select,
+)
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+if TYPE_CHECKING:
+ from .favourite_beatmapset import FavouriteBeatmapset
+ from .relationship import RelationshipResp
+
+
+class Kudosu(TypedDict):
+ available: int
+ total: int
+
+
+class RankHighest(TypedDict):
+ rank: int
+ updated_at: datetime
+
+
+class UserProfileCover(TypedDict):
+ url: str
+ custom_url: NotRequired[str]
+ id: NotRequired[str]
+
+
+Badge = TypedDict(
+ "Badge",
+ {
+ "awarded_at": datetime,
+ "description": str,
+ "image@2x_url": str,
+ "image_url": str,
+ "url": str,
+ },
+)
+
+
+class UserBase(UTCBaseModel, SQLModel):
+ avatar_url: str = ""
+ country_code: str = Field(default="CN", max_length=2, index=True)
+ # ? default_group: str|None
+ is_active: bool = True
+ is_bot: bool = False
+ is_supporter: bool = False
+ last_visit: datetime | None = Field(
+ default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
+ )
+ pm_friends_only: bool = False
+ profile_colour: str | None = None
+ username: str = Field(max_length=32, unique=True, index=True)
+ page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
+ previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON))
+ # TODO: replays_watched_counts
+ support_level: int = 0
+ badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
+
+ # optional
+ is_restricted: bool = False
+ # blocks
+ cover: UserProfileCover = Field(
+ default=UserProfileCover(
+ url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
+ ),
+ sa_column=Column(JSON),
+ )
+ beatmap_playcounts_count: int = 0
+ # kudosu
+
+ # UserExtended
+ playmode: GameMode = GameMode.OSU
+ discord: str | None = None
+ has_supported: bool = False
+ interests: str | None = None
+ join_date: datetime = Field(default=datetime.now(UTC))
+ location: str | None = None
+ max_blocks: int = 50
+ max_friends: int = 500
+ occupation: str | None = None
+ playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON))
+ # TODO: post_count
+ profile_hue: int | None = None
+ profile_order: list[str] = Field(
+ default_factory=lambda: [
+ "me",
+ "recent_activity",
+ "top_ranks",
+ "medals",
+ "historical",
+ "beatmaps",
+ "kudosu",
+ ],
+ sa_column=Column(JSON),
+ )
+ title: str | None = None
+ title_url: str | None = None
+ twitter: str | None = None
+ website: str | None = None
+
+ # undocumented
+ comments_count: int = 0
+ post_count: int = 0
+ is_admin: bool = False
+ is_gmt: bool = False
+ is_qat: bool = False
+ is_bng: bool = False
+
+
+class User(AsyncAttrs, UserBase, table=True):
+ __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
+
+ id: int | None = Field(
+ default=None,
+ sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
+ )
+ account_history: list[UserAccountHistory] = Relationship()
+ statistics: list[UserStatistics] = Relationship()
+ achievement: list[UserAchievement] = Relationship(back_populates="user")
+ team_membership: TeamMember | None = Relationship(back_populates="user")
+ daily_challenge_stats: DailyChallengeStats | None = Relationship(
+ back_populates="user"
+ )
+ monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
+ favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
+ back_populates="user"
+ )
+
+ email: str = Field(max_length=254, unique=True, index=True, exclude=True)
+ priv: int = Field(default=1, exclude=True)
+ pw_bcrypt: str = Field(max_length=60, exclude=True)
+ silence_end_at: datetime | None = Field(
+ default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
+ )
+ donor_end_at: datetime | None = Field(
+ default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
+ )
+
+
+class UserResp(UserBase):
+ id: int | None = None
+ is_online: bool = False
+ groups: list = [] # TODO
+ country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
+ favourite_beatmapset_count: int = 0 # TODO
+ graveyard_beatmapset_count: int = 0 # TODO
+ guest_beatmapset_count: int = 0 # TODO
+ loved_beatmapset_count: int = 0 # TODO
+ mapping_follower_count: int = 0 # TODO
+ nominated_beatmapset_count: int = 0 # TODO
+ pending_beatmapset_count: int = 0 # TODO
+ ranked_beatmapset_count: int = 0 # TODO
+ follow_user_mapping: list[int] = Field(default_factory=list)
+ follower_count: int = 0
+ friends: list["RelationshipResp"] | None = None
+ scores_best_count: int = 0
+ scores_first_count: int = 0
+ scores_recent_count: int = 0
+ scores_pinned_count: int = 0
+ account_history: list[UserAccountHistoryResp] = []
+ active_tournament_banners: list[dict] = [] # TODO
+ kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO
+ monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list)
+ unread_pm_count: int = 0 # TODO
+ rank_history: RankHistory | None = None # TODO
+ rank_highest: RankHighest | None = None # TODO
+ statistics: UserStatisticsResp | None = None
+ statistics_rulesets: dict[str, UserStatisticsResp] | None = None
+ user_achievements: list[UserAchievementResp] = Field(default_factory=list)
+ cover_url: str = "" # deprecated
+ team: Team | None = None
+ session_verified: bool = True
+ daily_challenge_user_stats: DailyChallengeStatsResp | None = None
+
+ # TODO: monthly_playcounts, unread_pm_count, rank_history, user_preferences
+
+ @classmethod
+ async def from_db(
+ cls,
+ obj: User,
+ session: AsyncSession,
+ include: list[str] = [],
+ ruleset: GameMode | None = None,
+ ) -> "UserResp":
+ from app.dependencies.database import get_redis
+
+ from .best_score import BestScore
+ from .relationship import Relationship, RelationshipResp, RelationshipType
+
+ u = cls.model_validate(obj.model_dump())
+ u.id = obj.id
+ u.follower_count = (
+ await session.exec(
+ select(func.count())
+ .select_from(Relationship)
+ .where(
+ Relationship.target_id == obj.id,
+ Relationship.type == RelationshipType.FOLLOW,
+ )
+ )
+ ).one()
+ u.scores_best_count = (
+ await session.exec(
+ select(func.count())
+ .select_from(BestScore)
+ .where(
+ BestScore.user_id == obj.id,
+ )
+ .limit(200)
+ )
+ ).one()
+ redis = get_redis()
+ u.is_online = await redis.exists(f"metadata:online:{obj.id}")
+ u.cover_url = (
+ obj.cover.get(
+ "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
+ )
+ if obj.cover
+ else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
+ )
+
+ if "friends" in include:
+ u.friends = [
+ await RelationshipResp.from_db(session, r)
+ for r in (
+ await session.exec(
+ select(Relationship).where(
+ Relationship.user_id == obj.id,
+ Relationship.type == RelationshipType.FOLLOW,
+ )
+ )
+ ).all()
+ ]
+
+ if "team" in include:
+ if await obj.awaitable_attrs.team_membership:
+ assert obj.team_membership
+ u.team = obj.team_membership.team
+
+ if "account_history" in include:
+ u.account_history = [
+ UserAccountHistoryResp.from_db(ah)
+ for ah in await obj.awaitable_attrs.account_history
+ ]
+
+ if "daily_challenge_user_stats":
+ if await obj.awaitable_attrs.daily_challenge_stats:
+ assert obj.daily_challenge_stats
+ u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
+ obj.daily_challenge_stats
+ )
+
+ if "statistics" in include:
+ current_stattistics = None
+ for i in await obj.awaitable_attrs.statistics:
+ if i.mode == (ruleset or obj.playmode):
+ current_stattistics = i
+ break
+ u.statistics = (
+ UserStatisticsResp.from_db(current_stattistics)
+ if current_stattistics
+ else None
+ )
+
+ if "statistics_rulesets" in include:
+ u.statistics_rulesets = {
+ i.mode.value: UserStatisticsResp.from_db(i)
+ for i in await obj.awaitable_attrs.statistics
+ }
+
+ if "monthly_playcounts" in include:
+ u.monthly_playcounts = [
+ MonthlyPlaycountsResp.from_db(pc)
+ for pc in await obj.awaitable_attrs.monthly_playcounts
+ ]
+
+ if "achievements" in include:
+ u.user_achievements = [
+ UserAchievementResp.from_db(ua)
+ for ua in await obj.awaitable_attrs.achievement
+ ]
+
+ return u
+
+
+ALL_INCLUDED = [
+ "friends",
+ "team",
+ "account_history",
+ "daily_challenge_user_stats",
+ "statistics",
+ "statistics_rulesets",
+ "achievements",
+ "monthly_playcounts",
+]
+
+
+SEARCH_INCLUDED = [
+ "team",
+ "daily_challenge_user_stats",
+ "statistics",
+ "statistics_rulesets",
+ "achievements",
+ "monthly_playcounts",
+]
+
+BASE_INCLUDES = [
+ "team",
+ "daily_challenge_user_stats",
+ "statistics",
+]
diff --git a/app/database/legacy.py b/app/database/legacy.py
deleted file mode 100644
index ff1e957..0000000
--- a/app/database/legacy.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from datetime import datetime
-from typing import TYPE_CHECKING
-
-from sqlalchemy import JSON, Column, DateTime
-from sqlalchemy.orm import Mapped
-from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
-
-if TYPE_CHECKING:
- from .user import User
-# ============================================
-# 旧的兼容性表模型(保留以便向后兼容)
-# ============================================
-
-
-class LegacyUserStatistics(SQLModel, table=True):
- __tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- mode: str = Field(max_length=10) # osu, taiko, fruits, mania
-
- # 基本统计
- count_100: int = Field(default=0)
- count_300: int = Field(default=0)
- count_50: int = Field(default=0)
- count_miss: int = Field(default=0)
-
- # 等级信息
- level_current: int = Field(default=1)
- level_progress: int = Field(default=0)
-
- # 排名信息
- global_rank: int | None = Field(default=None)
- global_rank_exp: int | None = Field(default=None)
- country_rank: int | None = Field(default=None)
-
- # PP 和分数
- pp: float = Field(default=0.0)
- pp_exp: float = Field(default=0.0)
- ranked_score: int = Field(default=0)
- hit_accuracy: float = Field(default=0.0)
- total_score: int = Field(default=0)
- total_hits: int = Field(default=0)
- maximum_combo: int = Field(default=0)
-
- # 游戏统计
- play_count: int = Field(default=0)
- play_time: int = Field(default=0)
- replays_watched_by_others: int = Field(default=0)
- is_ranked: bool = Field(default=False)
-
- # 成绩等级计数
- grade_ss: int = Field(default=0)
- grade_ssh: int = Field(default=0)
- grade_s: int = Field(default=0)
- grade_sh: int = Field(default=0)
- grade_a: int = Field(default=0)
-
- # 最高排名记录
- rank_highest: int | None = Field(default=None)
- rank_highest_updated_at: datetime | None = Field(
- default=None, sa_column=Column(DateTime)
- )
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- # 关联关系
- user: Mapped["User"] = Relationship(back_populates="statistics")
-
-
-class LegacyOAuthToken(SQLModel, table=True):
- __tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- access_token: str = Field(max_length=255, index=True)
- refresh_token: str = Field(max_length=255, index=True)
- expires_at: datetime = Field(sa_column=Column(DateTime))
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON))
- replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON))
-
- # 用户关系
- user: "User" = Relationship()
diff --git a/app/database/monthly_playcounts.py b/app/database/monthly_playcounts.py
new file mode 100644
index 0000000..46192d1
--- /dev/null
+++ b/app/database/monthly_playcounts.py
@@ -0,0 +1,43 @@
+from datetime import date
+from typing import TYPE_CHECKING
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
+
+if TYPE_CHECKING:
+ from .lazer_user import User
+
+
+class MonthlyPlaycounts(SQLModel, table=True):
+ __tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
+
+ id: int | None = Field(
+ default=None,
+ sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
+ )
+ user_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
+ )
+ year: int = Field(index=True)
+ month: int = Field(index=True)
+ playcount: int = Field(default=0)
+
+ user: "User" = Relationship(back_populates="monthly_playcounts")
+
+
+class MonthlyPlaycountsResp(SQLModel):
+ start_date: date
+ count: int
+
+ @classmethod
+ def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp":
+ return cls(
+ start_date=date(db_model.year, db_model.month, 1),
+ count=db_model.playcount,
+ )
diff --git a/app/database/multiplayer_event.py b/app/database/multiplayer_event.py
new file mode 100644
index 0000000..904fbe4
--- /dev/null
+++ b/app/database/multiplayer_event.py
@@ -0,0 +1,56 @@
+from datetime import UTC, datetime
+from typing import Any
+
+from app.models.model import UTCBaseModel
+
+from sqlmodel import (
+ JSON,
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ ForeignKey,
+ SQLModel,
+)
+
+
+class MultiplayerEventBase(SQLModel, UTCBaseModel):
+ playlist_item_id: int | None = None
+ user_id: int | None = Field(
+ default=None,
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
+ )
+ created_at: datetime = Field(
+ sa_column=Column(
+ DateTime(timezone=True),
+ ),
+ default=datetime.now(UTC),
+ )
+ event_type: str = Field(index=True)
+
+
+class MultiplayerEvent(MultiplayerEventBase, table=True):
+ __tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
+ id: int | None = Field(
+ default=None,
+ sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
+ )
+ room_id: int = Field(foreign_key="rooms.id", index=True)
+ updated_at: datetime = Field(
+ sa_column=Column(
+ DateTime(timezone=True),
+ ),
+ default=datetime.now(UTC),
+ )
+ event_detail: dict[str, Any] | None = Field(
+ sa_column=Column(JSON),
+ default_factory=dict,
+ )
+
+
+class MultiplayerEventResp(MultiplayerEventBase):
+ id: int
+
+ @classmethod
+ def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp":
+ return cls.model_validate(event)
diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py
new file mode 100644
index 0000000..93bc8c5
--- /dev/null
+++ b/app/database/playlist_attempts.py
@@ -0,0 +1,114 @@
+from .lazer_user import User, UserResp
+from .playlist_best_score import PlaylistBestScore
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+ col,
+ func,
+ select,
+)
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+
+class ItemAttemptsCountBase(SQLModel):
+ room_id: int = Field(foreign_key="rooms.id", index=True)
+ attempts: int = Field(default=0)
+ completed: int = Field(default=0)
+ user_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
+ )
+ accuracy: float = 0.0
+ pp: float = 0
+ total_score: int = 0
+
+
+class ItemAttemptsCount(ItemAttemptsCountBase, table=True):
+ __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
+ id: int | None = Field(default=None, primary_key=True)
+
+ user: User = Relationship()
+
+ async def get_position(self, session: AsyncSession) -> int:
+ rownum = (
+ func.row_number()
+ .over(
+ partition_by=col(ItemAttemptsCountBase.room_id),
+ order_by=col(ItemAttemptsCountBase.total_score).desc(),
+ )
+ .label("rn")
+ )
+ subq = select(ItemAttemptsCountBase, rownum).subquery()
+ stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
+ result = await session.exec(stmt)
+ return result.one()
+
+ async def update(self, session: AsyncSession):
+ playlist_scores = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.room_id == self.room_id,
+ PlaylistBestScore.user_id == self.user_id,
+ )
+ )
+ ).all()
+ self.attempts = sum(score.attempts for score in playlist_scores)
+ self.total_score = sum(score.total_score for score in playlist_scores)
+ self.pp = sum(score.score.pp for score in playlist_scores)
+ self.completed = len(playlist_scores)
+ self.accuracy = (
+ sum(score.score.accuracy * score.attempts for score in playlist_scores)
+ / self.completed
+ if self.completed > 0
+ else 0.0
+ )
+ await session.commit()
+ await session.refresh(self)
+
+ @classmethod
+ async def get_or_create(
+ cls,
+ room_id: int,
+ user_id: int,
+ session: AsyncSession,
+ ) -> "ItemAttemptsCount":
+ item_attempts = await session.exec(
+ select(cls).where(
+ cls.room_id == room_id,
+ cls.user_id == user_id,
+ )
+ )
+ item_attempts = item_attempts.first()
+ if item_attempts is None:
+ item_attempts = cls(room_id=room_id, user_id=user_id)
+ session.add(item_attempts)
+ await session.commit()
+ await session.refresh(item_attempts)
+ await item_attempts.update(session)
+ return item_attempts
+
+
+class ItemAttemptsResp(ItemAttemptsCountBase):
+ user: UserResp | None = None
+ position: int | None = None
+
+ @classmethod
+ async def from_db(
+ cls,
+ item_attempts: ItemAttemptsCount,
+ session: AsyncSession,
+ include: list[str] = [],
+ ) -> "ItemAttemptsResp":
+ resp = cls.model_validate(item_attempts)
+ resp.user = await UserResp.from_db(
+ item_attempts.user,
+ session=session,
+ include=["statistics", "team", "daily_challenge_user_stats"],
+ )
+ if "position" in include:
+ resp.position = await item_attempts.get_position(session)
+ return resp
diff --git a/app/database/playlist_best_score.py b/app/database/playlist_best_score.py
new file mode 100644
index 0000000..46bbfba
--- /dev/null
+++ b/app/database/playlist_best_score.py
@@ -0,0 +1,109 @@
+from typing import TYPE_CHECKING
+
+from .lazer_user import User
+
+from redis.asyncio import Redis
+from sqlmodel import (
+ BigInteger,
+ Column,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+ col,
+ func,
+ select,
+)
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+if TYPE_CHECKING:
+ from .score import Score
+
+
+class PlaylistBestScore(SQLModel, table=True):
+ __tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
+
+ user_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
+ )
+ score_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
+ )
+ room_id: int = Field(foreign_key="rooms.id", index=True)
+ playlist_id: int = Field(foreign_key="room_playlists.id", index=True)
+ total_score: int = Field(default=0, sa_column=Column(BigInteger))
+ attempts: int = Field(default=0) # playlist
+
+ user: User = Relationship()
+ score: "Score" = Relationship(
+ sa_relationship_kwargs={
+ "foreign_keys": "[PlaylistBestScore.score_id]",
+ "lazy": "joined",
+ }
+ )
+
+
+async def process_playlist_best_score(
+ room_id: int,
+ playlist_id: int,
+ user_id: int,
+ score_id: int,
+ total_score: int,
+ session: AsyncSession,
+ redis: Redis,
+):
+ previous = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.room_id == room_id,
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.user_id == user_id,
+ )
+ )
+ ).first()
+ if previous is None:
+ score = PlaylistBestScore(
+ user_id=user_id,
+ score_id=score_id,
+ room_id=room_id,
+ playlist_id=playlist_id,
+ total_score=total_score,
+ )
+ session.add(score)
+ else:
+ previous.score_id = score_id
+ previous.total_score = total_score
+ previous.attempts += 1
+ await session.commit()
+ await redis.decr(f"multiplayer:{room_id}:gameplay:players")
+
+
+async def get_position(
+ room_id: int,
+ playlist_id: int,
+ score_id: int,
+ session: AsyncSession,
+) -> int:
+ rownum = (
+ func.row_number()
+ .over(
+ partition_by=(
+ col(PlaylistBestScore.playlist_id),
+ col(PlaylistBestScore.room_id),
+ ),
+ order_by=col(PlaylistBestScore.total_score).desc(),
+ )
+ .label("row_number")
+ )
+ subq = (
+ select(PlaylistBestScore, rownum)
+ .where(
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.room_id == room_id,
+ )
+ .subquery()
+ )
+ stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
+ result = await session.exec(stmt)
+ s = result.one_or_none()
+ return s if s else 0
diff --git a/app/database/playlists.py b/app/database/playlists.py
new file mode 100644
index 0000000..3f7ae40
--- /dev/null
+++ b/app/database/playlists.py
@@ -0,0 +1,143 @@
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from app.models.model import UTCBaseModel
+from app.models.mods import APIMod
+from app.models.multiplayer_hub import PlaylistItem
+
+from .beatmap import Beatmap, BeatmapResp
+
+from sqlmodel import (
+ JSON,
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+ func,
+ select,
+)
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+if TYPE_CHECKING:
+ from .room import Room
+
+
+class PlaylistBase(SQLModel, UTCBaseModel):
+ id: int = Field(index=True)
+ owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
+ ruleset_id: int = Field(ge=0, le=3)
+ expired: bool = Field(default=False)
+ playlist_order: int = Field(default=0)
+ played_at: datetime | None = Field(
+ sa_column=Column(DateTime(timezone=True)),
+ default=None,
+ )
+ allowed_mods: list[APIMod] = Field(
+ default_factory=list,
+ sa_column=Column(JSON),
+ )
+ required_mods: list[APIMod] = Field(
+ default_factory=list,
+ sa_column=Column(JSON),
+ )
+ beatmap_id: int = Field(
+ foreign_key="beatmaps.id",
+ )
+ freestyle: bool = Field(default=False)
+
+
+class Playlist(PlaylistBase, table=True):
+ __tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
+ db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
+ room_id: int = Field(foreign_key="rooms.id", exclude=True)
+
+ beatmap: Beatmap = Relationship(
+ sa_relationship_kwargs={
+ "lazy": "joined",
+ }
+ )
+ room: "Room" = Relationship()
+
+ @classmethod
+ async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
+ stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
+ cls.room_id == room_id
+ )
+ result = await session.exec(stmt)
+ return result.one()
+
+ @classmethod
+ async def from_hub(
+ cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
+ ) -> "Playlist":
+ next_id = await cls.get_next_id_for_room(room_id, session=session)
+ return cls(
+ id=next_id,
+ owner_id=playlist.owner_id,
+ ruleset_id=playlist.ruleset_id,
+ beatmap_id=playlist.beatmap_id,
+ required_mods=playlist.required_mods,
+ allowed_mods=playlist.allowed_mods,
+ expired=playlist.expired,
+ playlist_order=playlist.playlist_order,
+ played_at=playlist.played_at,
+ freestyle=playlist.freestyle,
+ room_id=room_id,
+ )
+
+ @classmethod
+ async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
+ db_playlist = await session.exec(
+ select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
+ )
+ db_playlist = db_playlist.first()
+ if db_playlist is None:
+ raise ValueError("Playlist item not found")
+ db_playlist.owner_id = playlist.owner_id
+ db_playlist.ruleset_id = playlist.ruleset_id
+ db_playlist.beatmap_id = playlist.beatmap_id
+ db_playlist.required_mods = playlist.required_mods
+ db_playlist.allowed_mods = playlist.allowed_mods
+ db_playlist.expired = playlist.expired
+ db_playlist.playlist_order = playlist.playlist_order
+ db_playlist.played_at = playlist.played_at
+ db_playlist.freestyle = playlist.freestyle
+ await session.commit()
+
+ @classmethod
+ async def add_to_db(
+ cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
+ ):
+ db_playlist = await cls.from_hub(playlist, room_id, session)
+ session.add(db_playlist)
+ await session.commit()
+ await session.refresh(db_playlist)
+ playlist.id = db_playlist.id
+
+ @classmethod
+ async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
+ db_playlist = await session.exec(
+ select(cls).where(cls.id == item_id, cls.room_id == room_id)
+ )
+ db_playlist = db_playlist.first()
+ if db_playlist is None:
+ raise ValueError("Playlist item not found")
+ await session.delete(db_playlist)
+ await session.commit()
+
+
+class PlaylistResp(PlaylistBase):
+ beatmap: BeatmapResp | None = None
+
+ @classmethod
+ async def from_db(
+ cls, playlist: Playlist, include: list[str] = []
+ ) -> "PlaylistResp":
+ data = playlist.model_dump()
+ if "beatmap" in include:
+ data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True)
+ resp = cls.model_validate(data)
+ return resp
diff --git a/app/database/pp_best_score.py b/app/database/pp_best_score.py
new file mode 100644
index 0000000..ffc74d3
--- /dev/null
+++ b/app/database/pp_best_score.py
@@ -0,0 +1,41 @@
+from typing import TYPE_CHECKING
+
+from app.models.score import GameMode
+
+from .lazer_user import User
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ Field,
+ Float,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
+
+if TYPE_CHECKING:
+ from .beatmap import Beatmap
+ from .score import Score
+
+
+class PPBestScore(SQLModel, table=True):
+ __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
+ user_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
+ )
+ score_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
+ )
+ beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
+ gamemode: GameMode = Field(index=True)
+ pp: float = Field(
+ sa_column=Column(Float, default=0),
+ )
+ acc: float = Field(
+ sa_column=Column(Float, default=0),
+ )
+
+ user: User = Relationship()
+ score: "Score" = Relationship()
+ beatmap: "Beatmap" = Relationship()
diff --git a/app/database/relationship.py b/app/database/relationship.py
index 61dc109..b941c28 100644
--- a/app/database/relationship.py
+++ b/app/database/relationship.py
@@ -1,8 +1,6 @@
from enum import Enum
-from app.models.user import User as APIUser
-
-from .user import User as DBUser
+from .lazer_user import User, UserResp
from pydantic import BaseModel
from sqlmodel import (
@@ -24,12 +22,16 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
+ id: int | None = Field(
+ default=None,
+ sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
+ exclude=True,
+ )
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
+ ForeignKey("lazer_users.id"),
index=True,
),
)
@@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True):
default=None,
sa_column=Column(
BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
+ ForeignKey("lazer_users.id"),
index=True,
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
- target: DBUser = SQLRelationship(
- sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
+ target: User = SQLRelationship(
+ sa_relationship_kwargs={
+ "foreign_keys": "[Relationship.target_id]",
+ "lazy": "selectin",
+ }
)
class RelationshipResp(BaseModel):
target_id: int
- target: APIUser
+ target: UserResp
mutual: bool = False
type: RelationshipType
@@ -58,8 +62,6 @@ class RelationshipResp(BaseModel):
async def from_db(
cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp":
- from app.utils import convert_db_user_to_api_user
-
target_relationship = (
await session.exec(
select(Relationship).where(
@@ -75,7 +77,16 @@ class RelationshipResp(BaseModel):
)
return cls(
target_id=relationship.target_id,
- target=await convert_db_user_to_api_user(relationship.target),
+ target=await UserResp.from_db(
+ relationship.target,
+ session,
+ include=[
+ "team",
+ "daily_challenge_user_stats",
+ "statistics",
+ "statistics_rulesets",
+ ],
+ ),
mutual=mutual,
type=relationship.type,
)
diff --git a/app/database/room.py b/app/database/room.py
index 0b79ee6..7817805 100644
--- a/app/database/room.py
+++ b/app/database/room.py
@@ -1,6 +1,125 @@
-from sqlmodel import Field, SQLModel
+from datetime import UTC, datetime
+
+from app.models.multiplayer_hub import ServerMultiplayerRoom
+from app.models.room import (
+ MatchType,
+ QueueMode,
+ RoomCategory,
+ RoomDifficultyRange,
+ RoomPlaylistItemStats,
+ RoomStatus,
+)
+
+from .lazer_user import User, UserResp
+from .playlists import Playlist, PlaylistResp
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ DateTime,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
-class RoomIndex(SQLModel, table=True):
- __tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType]
- id: int = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue]
+class RoomBase(SQLModel):
+ name: str = Field(index=True)
+ category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
+ duration: int | None = Field(default=None) # minutes
+ starts_at: datetime | None = Field(
+ sa_column=Column(
+ DateTime(timezone=True),
+ ),
+ default=datetime.now(UTC),
+ )
+ ended_at: datetime | None = Field(
+ sa_column=Column(
+ DateTime(timezone=True),
+ ),
+ default=None,
+ )
+ participant_count: int = Field(default=0)
+ max_attempts: int | None = Field(default=None) # playlists
+ type: MatchType
+ queue_mode: QueueMode
+ auto_skip: bool
+ auto_start_duration: int
+ status: RoomStatus
+ # TODO: channel_id
+ # recent_participants: list[User]
+
+
+class Room(RoomBase, table=True):
+ __tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
+ id: int = Field(default=None, primary_key=True, index=True)
+ host_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
+ )
+
+ host: User = Relationship()
+ playlist: list[Playlist] = Relationship(
+ sa_relationship_kwargs={
+ "lazy": "joined",
+ "cascade": "all, delete-orphan",
+ "overlaps": "room",
+ }
+ )
+
+
+class RoomResp(RoomBase):
+ id: int
+ password: str | None = None
+ host: UserResp | None = None
+ playlist: list[PlaylistResp] = []
+ playlist_item_stats: RoomPlaylistItemStats | None = None
+ difficulty_range: RoomDifficultyRange | None = None
+ current_playlist_item: PlaylistResp | None = None
+
+ @classmethod
+ async def from_db(cls, room: Room) -> "RoomResp":
+ resp = cls.model_validate(room.model_dump())
+
+ stats = RoomPlaylistItemStats(count_active=0, count_total=0)
+ difficulty_range = RoomDifficultyRange(
+ min=0,
+ max=0,
+ )
+ rulesets = set()
+ for playlist in room.playlist:
+ stats.count_total += 1
+ if not playlist.expired:
+ stats.count_active += 1
+ rulesets.add(playlist.ruleset_id)
+ difficulty_range.min = min(
+ difficulty_range.min, playlist.beatmap.difficulty_rating
+ )
+ difficulty_range.max = max(
+ difficulty_range.max, playlist.beatmap.difficulty_rating
+ )
+ resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
+ stats.ruleset_ids = list(rulesets)
+ resp.playlist_item_stats = stats
+ resp.difficulty_range = difficulty_range
+ resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
+
+ return resp
+
+ @classmethod
+ async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
+ room = server_room.room
+ resp = cls(
+ id=room.room_id,
+ name=room.settings.name,
+ type=room.settings.match_type,
+ queue_mode=room.settings.queue_mode,
+ auto_skip=room.settings.auto_skip,
+ auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
+ status=server_room.status,
+ category=server_room.category,
+ # duration = room.settings.duration,
+ starts_at=server_room.start_at,
+ participant_count=len(room.users),
+ )
+ return resp
diff --git a/app/database/score.py b/app/database/score.py
index 046f83c..37b96a3 100644
--- a/app/database/score.py
+++ b/app/database/score.py
@@ -1,8 +1,9 @@
import asyncio
from collections.abc import Sequence
-from datetime import UTC, datetime
+from datetime import UTC, date, datetime
+import json
import math
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from app.calculator import (
calculate_pp,
@@ -12,9 +13,8 @@ from app.calculator import (
calculate_weighted_pp,
clamp,
)
-from app.database.score_token import ScoreToken
-from app.database.user import LazerUserStatistics, User
-from app.models.beatmap import BeatmapRankStatus
+from app.database.team import TeamMember
+from app.models.model import RespWithCursor, UTCBaseModel
from app.models.mods import APIMod, mods_can_get_pp
from app.models.score import (
INT_TO_MODE,
@@ -26,15 +26,24 @@ from app.models.score import (
ScoreStatistics,
SoloScoreSubmissionInfo,
)
-from app.models.user import User as APIUser
from .beatmap import Beatmap, BeatmapResp
-from .beatmapset import Beatmapset, BeatmapsetResp
+from .beatmapset import BeatmapsetResp
from .best_score import BestScore
+from .lazer_user import User, UserResp
+from .monthly_playcounts import MonthlyPlaycounts
+from .pp_best_score import PPBestScore
+from .relationship import (
+ Relationship as DBRelationship,
+ RelationshipType,
+)
+from .score_token import ScoreToken
-from redis import Redis
+from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
-from sqlalchemy.orm import aliased, joinedload
+from sqlalchemy.ext.asyncio import AsyncAttrs
+from sqlalchemy.orm import aliased
+from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import (
JSON,
BigInteger,
@@ -43,9 +52,10 @@ from sqlmodel import (
Relationship,
SQLModel,
col,
- false,
func,
select,
+ text,
+ true,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectOfScalar
@@ -54,7 +64,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher
-class ScoreBase(SQLModel):
+class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
accuracy: float
map_md5: str = Field(max_length=32, index=True)
@@ -78,10 +88,11 @@ class ScoreBase(SQLModel):
default=0, sa_column=Column(BigInteger), exclude=True
)
type: str
+ beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
# optional
# TODO: current_user_attributes
- position: int | None = Field(default=None) # multiplayer
+ # position: int | None = Field(default=None) # multiplayer
class Score(ScoreBase, table=True):
@@ -89,12 +100,11 @@ class Score(ScoreBase, table=True):
id: int | None = Field(
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
- beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
- ForeignKey("users.id"),
+ ForeignKey("lazer_users.id"),
index=True,
),
)
@@ -112,28 +122,13 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True)
# optional
- beatmap: "Beatmap" = Relationship()
- user: "User" = Relationship()
+ beatmap: Beatmap = Relationship()
+ user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo
- @staticmethod
- def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
- clause = select(Score).options(
- joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
- .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
- .selectinload(
- Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
- ),
- )
- if with_user:
- return clause.options(
- joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
- )
- return clause
-
@staticmethod
def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -147,18 +142,7 @@ class Score(ScoreBase, table=True):
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
- return (
- select(best)
- .where(subq.c.rn == 1)
- .options(
- joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
- .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
- .selectinload(
- Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
- ),
- joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
- )
- )
+ return select(best).where(subq.c.rn == 1)
class ScoreResp(ScoreBase):
@@ -173,22 +157,23 @@ class ScoreResp(ScoreBase):
ruleset_id: int | None = None
beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None
- user: APIUser | None = None
+ user: UserResp | None = None
statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
+ position: int | None = None
+ scores_around: "ScoreAround | None" = None
@classmethod
- async def from_db(
- cls, session: AsyncSession, score: Score, user: User | None = None
- ) -> "ScoreResp":
- from app.utils import convert_db_user_to_api_user
-
+ async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
s = cls.model_validate(score.model_dump())
assert score.id
- s.beatmap = BeatmapResp.from_db(score.beatmap)
- s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
+ await score.awaitable_attrs.beatmap
+ s.beatmap = await BeatmapResp.from_db(score.beatmap)
+ s.beatmapset = await BeatmapsetResp.from_db(
+ score.beatmap.beatmapset, session=session, user=score.user
+ )
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -220,163 +205,184 @@ class ScoreResp(ScoreBase):
s.maximum_statistics = {
HitResult.GREAT: score.beatmap.max_combo,
}
- if user:
- s.user = await convert_db_user_to_api_user(user)
+ s.user = await UserResp.from_db(
+ score.user,
+ session,
+ include=["statistics", "team", "daily_challenge_user_stats"],
+ ruleset=score.gamemode,
+ )
s.rank_global = (
await get_score_position_by_id(
session,
- score.map_md5,
+ score.beatmap_id,
score.id,
mode=score.gamemode,
- user=user or score.user,
+ user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
- score.map_md5,
+ score.beatmap_id,
score.id,
score.gamemode,
- user or score.user,
+ score.user,
+ type=LeaderboardType.COUNTRY,
)
or None
)
return s
+class MultiplayerScores(RespWithCursor):
+ scores: list[ScoreResp] = Field(default_factory=list)
+ params: dict[str, Any] = Field(default_factory=dict)
+
+
+class ScoreAround(SQLModel):
+ higher: MultiplayerScores | None = None
+ lower: MultiplayerScores | None = None
+
+
async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = (
func.row_number()
- .over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc())
+ .over(
+ partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
+ )
.label("rn")
)
- subq = select(BestScore, rownum).subquery()
+ subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
return result.one_or_none()
+async def _score_where(
+ type: LeaderboardType,
+ beatmap: int,
+ mode: GameMode,
+ mods: list[str] | None = None,
+ user: User | None = None,
+) -> list[ColumnElement[bool]] | None:
+ wheres = [
+ col(BestScore.beatmap_id) == beatmap,
+ col(BestScore.gamemode) == mode,
+ ]
+
+ if type == LeaderboardType.FRIENDS:
+ if user and user.is_supporter:
+ subq = (
+ select(DBRelationship.target_id)
+ .where(
+ DBRelationship.type == RelationshipType.FOLLOW,
+ DBRelationship.user_id == user.id,
+ )
+ .subquery()
+ )
+ wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
+ else:
+ return None
+ elif type == LeaderboardType.COUNTRY:
+ if user and user.is_supporter:
+ wheres.append(
+ col(BestScore.user).has(col(User.country_code) == user.country_code)
+ )
+ else:
+ return None
+ elif type == LeaderboardType.TEAM:
+ if user:
+ team_membership = await user.awaitable_attrs.team_membership
+ if team_membership:
+ team_id = team_membership.team_id
+ wheres.append(
+ col(BestScore.user).has(
+ col(User.team_membership).has(TeamMember.team_id == team_id)
+ )
+ )
+ if mods:
+ if user and user.is_supporter:
+ wheres.append(
+ text(
+ "JSON_CONTAINS(total_score_best_scores.mods, :w)"
+ " AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
+ ) # pyright: ignore[reportArgumentType]
+ )
+ else:
+ return None
+ return wheres
+
+
async def get_leaderboard(
session: AsyncSession,
- beatmap_md5: str,
+ beatmap: int,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
- mods: list[APIMod] | None = None,
+ mods: list[str] | None = None,
user: User | None = None,
limit: int = 50,
-) -> list[Score]:
- scores = []
- if type == LeaderboardType.GLOBAL:
- query = (
- select(Score)
- .where(
- col(Beatmap.beatmap_status).in_(
- [
- BeatmapRankStatus.RANKED,
- BeatmapRankStatus.LOVED,
- BeatmapRankStatus.QUALIFIED,
- BeatmapRankStatus.APPROVED,
- ]
- ),
- Score.map_md5 == beatmap_md5,
- Score.gamemode == mode,
- col(Score.passed).is_(True),
- Score.mods == mods if user and user.is_supporter else false(),
- )
- .limit(limit)
- .order_by(
- col(Score.total_score).desc(),
- )
- )
- result = await session.exec(query)
- scores = list[Score](result.all())
- elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
- # TODO
- ...
- elif type == LeaderboardType.TEAM and user and user.team_membership:
- team_id = user.team_membership.team_id
- query = (
- select(Score)
- .join(Beatmap)
- .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
- .where(
- Score.map_md5 == beatmap_md5,
- Score.gamemode == mode,
- col(Score.passed).is_(True),
- col(Score.user.team_membership).is_not(None),
- Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
- Score.mods == mods if user and user.is_supporter else false(),
- )
- .limit(limit)
- .order_by(
- col(Score.total_score).desc(),
- )
- )
- result = await session.exec(query)
- scores = list[Score](result.all())
+) -> tuple[list[Score], Score | None]:
+ wheres = await _score_where(type, beatmap, mode, mods, user)
+ if wheres is None:
+ return [], None
+ query = (
+ select(BestScore)
+ .where(*wheres)
+ .limit(limit)
+ .order_by(col(BestScore.total_score).desc())
+ )
+ if mods:
+ query = query.params(w=json.dumps(mods))
+ scores = [s.score for s in await session.exec(query)]
+ user_score = None
if user:
- user_score = (
- await session.exec(
- select(Score).where(
- Score.map_md5 == beatmap_md5,
- Score.gamemode == mode,
- Score.user_id == user.id,
- col(Score.passed).is_(True),
- )
+ self_query = (
+ select(BestScore)
+ .where(BestScore.user_id == user.id)
+ .where(
+ col(BestScore.beatmap_id) == beatmap,
+ col(BestScore.gamemode) == mode,
)
- ).first()
+ .order_by(col(BestScore.total_score).desc())
+ .limit(1)
+ )
+ if mods:
+ self_query = self_query.where(
+ text(
+ "JSON_CONTAINS(total_score_best_scores.mods, :w)"
+ " AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
+ )
+ ).params(w=json.dumps(mods))
+ user_bs = (await session.exec(self_query)).first()
+ if user_bs:
+ user_score = user_bs.score
if user_score and user_score not in scores:
scores.append(user_score)
- return scores
+ return scores, user_score
async def get_score_position_by_user(
session: AsyncSession,
- beatmap_md5: str,
+ beatmap: int,
user: User,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
- mods: list[APIMod] | None = None,
+ mods: list[str] | None = None,
) -> int:
- where_clause = [
- Score.map_md5 == beatmap_md5,
- Score.gamemode == mode,
- col(Score.passed).is_(True),
- col(Beatmap.beatmap_status).in_(
- [
- BeatmapRankStatus.RANKED,
- BeatmapRankStatus.LOVED,
- BeatmapRankStatus.QUALIFIED,
- BeatmapRankStatus.APPROVED,
- ]
- ),
- ]
- if mods and user.is_supporter:
- where_clause.append(Score.mods == mods)
- else:
- where_clause.append(false())
- if type == LeaderboardType.FRIENDS and user.is_supporter:
- # TODO
- ...
- elif type == LeaderboardType.TEAM and user.team_membership:
- team_id = user.team_membership.team_id
- where_clause.append(
- col(Score.user.team_membership).is_not(None),
- )
- where_clause.append(
- Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
- )
+ wheres = await _score_where(type, beatmap, mode, mods, user=user)
+ if wheres is None:
+ return 0
rownum = (
func.row_number()
.over(
- partition_by=Score.map_md5,
- order_by=col(Score.total_score).desc(),
+ partition_by=col(BestScore.beatmap_id),
+ order_by=col(BestScore.total_score).desc(),
)
.label("row_number")
)
- subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery()
- stmt = select(subq.c.row_number).where(subq.c.user == user)
+ subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
+ stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
@@ -384,57 +390,26 @@ async def get_score_position_by_user(
async def get_score_position_by_id(
session: AsyncSession,
- beatmap_md5: str,
+ beatmap: int,
score_id: int,
mode: GameMode,
user: User | None = None,
type: LeaderboardType = LeaderboardType.GLOBAL,
- mods: list[APIMod] | None = None,
+ mods: list[str] | None = None,
) -> int:
- where_clause = [
- Score.map_md5 == beatmap_md5,
- Score.id == score_id,
- Score.gamemode == mode,
- col(Score.passed).is_(True),
- col(Beatmap.beatmap_status).in_(
- [
- BeatmapRankStatus.RANKED,
- BeatmapRankStatus.LOVED,
- BeatmapRankStatus.QUALIFIED,
- BeatmapRankStatus.APPROVED,
- ]
- ),
- ]
- if mods and user and user.is_supporter:
- where_clause.append(Score.mods == mods)
- elif mods:
- where_clause.append(false())
+ wheres = await _score_where(type, beatmap, mode, mods, user=user)
+ if wheres is None:
+ return 0
rownum = (
func.row_number()
.over(
- partition_by=[col(Score.user_id), col(Score.map_md5)],
- order_by=col(Score.total_score).desc(),
+ partition_by=col(BestScore.beatmap_id),
+ order_by=col(BestScore.total_score).desc(),
)
- .label("rownum")
+ .label("row_number")
)
- subq = (
- select(Score.user_id, Score.id, Score.total_score, rownum)
- .join(Beatmap)
- .where(*where_clause)
- .subquery()
- )
- best_scores = aliased(subq)
- overall_rank = (
- func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
- )
- final_q = (
- select(best_scores.c.id, overall_rank)
- .select_from(best_scores)
- .where(best_scores.c.rownum == 1)
- .subquery()
- )
-
- stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
+ subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
+ stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
@@ -445,16 +420,38 @@ async def get_user_best_score_in_beatmap(
beatmap: int,
user: int,
mode: GameMode | None = None,
-) -> Score | None:
+) -> BestScore | None:
return (
await session.exec(
- Score.select_clause(False)
+ select(BestScore)
.where(
- Score.gamemode == mode if mode is not None else True,
- Score.beatmap_id == beatmap,
- Score.user_id == user,
+ BestScore.gamemode == mode if mode is not None else true(),
+ BestScore.beatmap_id == beatmap,
+ BestScore.user_id == user,
)
- .order_by(col(Score.total_score).desc())
+ .order_by(col(BestScore.total_score).desc())
+ )
+ ).first()
+
+
+# FIXME
+async def get_user_best_score_with_mod_in_beatmap(
+ session: AsyncSession,
+ beatmap: int,
+ user: int,
+ mod: list[str],
+ mode: GameMode | None = None,
+) -> BestScore | None:
+ return (
+ await session.exec(
+ select(BestScore)
+ .where(
+ BestScore.gamemode == mode if mode is not None else True,
+ BestScore.beatmap_id == beatmap,
+ BestScore.user_id == user,
+ # BestScore.mods == mod,
+ )
+ .order_by(col(BestScore.total_score).desc())
)
).first()
@@ -464,13 +461,13 @@ async def get_user_best_pp_in_beatmap(
beatmap: int,
user: int,
mode: GameMode,
-) -> BestScore | None:
+) -> PPBestScore | None:
return (
await session.exec(
- select(BestScore).where(
- BestScore.beatmap_id == beatmap,
- BestScore.user_id == user,
- BestScore.gamemode == mode,
+ select(PPBestScore).where(
+ PPBestScore.beatmap_id == beatmap,
+ PPBestScore.user_id == user,
+ PPBestScore.gamemode == mode,
)
)
).first()
@@ -480,12 +477,12 @@ async def get_user_best_pp(
session: AsyncSession,
user: int,
limit: int = 200,
-) -> Sequence[BestScore]:
+) -> Sequence[PPBestScore]:
return (
await session.exec(
- select(BestScore)
- .where(BestScore.user_id == user)
- .order_by(col(BestScore.pp).desc())
+ select(PPBestScore)
+ .where(PPBestScore.user_id == user)
+ .order_by(col(PPBestScore.pp).desc())
.limit(limit)
)
).all()
@@ -494,27 +491,45 @@ async def get_user_best_pp(
async def process_user(
session: AsyncSession, user: User, score: Score, ranked: bool = False
):
+ assert user.id
+ assert score.id
+ mod_for_save = list({mod["acronym"] for mod in score.mods})
previous_score_best = await get_user_best_score_in_beatmap(
session, score.beatmap_id, user.id, score.gamemode
)
- statistics = None
+ previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
+ session, score.beatmap_id, user.id, mod_for_save, score.gamemode
+ )
add_to_db = False
- for i in user.lazer_statistics:
+ mouthly_playcount = (
+ await session.exec(
+ select(MonthlyPlaycounts).where(
+ MonthlyPlaycounts.user_id == user.id,
+ MonthlyPlaycounts.year == date.today().year,
+ MonthlyPlaycounts.month == date.today().month,
+ )
+ )
+ ).first()
+ if mouthly_playcount is None:
+ mouthly_playcount = MonthlyPlaycounts(
+ user_id=user.id, year=date.today().year, month=date.today().month
+ )
+ add_to_db = True
+ statistics = None
+ for i in await user.awaitable_attrs.statistics:
if i.mode == score.gamemode.value:
statistics = i
break
if statistics is None:
- statistics = LazerUserStatistics(
- mode=score.gamemode.value,
- user_id=user.id,
+ raise ValueError(
+ f"User {user.id} does not have statistics for mode {score.gamemode.value}"
)
- add_to_db = True
# pc, pt, tth, tts
statistics.total_score += score.total_score
difference = (
score.total_score - previous_score_best.total_score
- if previous_score_best and previous_score_best.id != score.id
+ if previous_score_best
else score.total_score
)
if difference > 0 and score.passed and ranked:
@@ -541,11 +556,48 @@ async def process_user(
statistics.grade_sh -= 1
case Rank.A:
statistics.grade_a -= 1
+ else:
+ previous_score_best = BestScore(
+ user_id=user.id,
+ beatmap_id=score.beatmap_id,
+ gamemode=score.gamemode,
+ score_id=score.id,
+ total_score=score.total_score,
+ rank=score.rank,
+ mods=mod_for_save,
+ )
+ session.add(previous_score_best)
+
statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.ranked_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
+ if score.passed and ranked:
+ if previous_score_best_mod is not None:
+ previous_score_best_mod.mods = mod_for_save
+ previous_score_best_mod.score_id = score.id
+ previous_score_best_mod.rank = score.rank
+ previous_score_best_mod.total_score = score.total_score
+ elif (
+ previous_score_best is not None and previous_score_best.score_id != score.id
+ ):
+ session.add(
+ BestScore(
+ user_id=user.id,
+ beatmap_id=score.beatmap_id,
+ gamemode=score.gamemode,
+ score_id=score.id,
+ total_score=score.total_score,
+ rank=score.rank,
+ mods=mod_for_save,
+ )
+ )
statistics.play_count += 1
+ mouthly_playcount.playcount += 1
statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
+ statistics.count_100 += score.n100 + score.nkatu
+ statistics.count_300 += score.n300 + score.ngeki
+ statistics.count_50 += score.n50
+ statistics.count_miss += score.nmiss
statistics.total_hits += (
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
)
@@ -563,11 +615,8 @@ async def process_user(
acc_sum = clamp(acc_sum, 0.0, 100.0)
statistics.pp = pp_sum
statistics.hit_accuracy = acc_sum
-
- statistics.updated_at = datetime.now(UTC)
-
if add_to_db:
- session.add(statistics)
+ session.add(mouthly_playcount)
await session.commit()
await session.refresh(user)
@@ -581,7 +630,10 @@ async def process_score(
fetcher: "Fetcher",
session: AsyncSession,
redis: Redis,
+ item_id: int | None = None,
+ room_id: int | None = None,
) -> Score:
+ assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
score = Score(
accuracy=info.accuracy,
@@ -611,6 +663,8 @@ async def process_score(
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
+ playlist_item_id=item_id,
+ room_id=room_id,
)
if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
@@ -628,7 +682,7 @@ async def process_score(
)
if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id
- best_score = BestScore(
+ best_score = PPBestScore(
user_id=user_id,
score_id=score.id,
beatmap_id=beatmap_id,
@@ -637,7 +691,7 @@ async def process_score(
acc=score.accuracy,
)
session.add(best_score)
- session.delete(previous_pp_best) if previous_pp_best else None
+ await session.delete(previous_pp_best) if previous_pp_best else None
await session.commit()
await session.refresh(score)
await session.refresh(score_token)
diff --git a/app/database/score_token.py b/app/database/score_token.py
index 6a6edb3..4467b8b 100644
--- a/app/database/score_token.py
+++ b/app/database/score_token.py
@@ -1,15 +1,16 @@
from datetime import datetime
+from app.models.model import UTCBaseModel
from app.models.score import GameMode
from .beatmap import Beatmap
-from .user import User
+from .lazer_user import User
from sqlalchemy import Column, DateTime, Index
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
-class ScoreTokenBase(SQLModel):
+class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist
@@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True):
autoincrement=True,
),
)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
+ user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id")
- user: "User" = Relationship()
- beatmap: "Beatmap" = Relationship()
+ user: User = Relationship()
+ beatmap: Beatmap = Relationship()
class ScoreTokenResp(ScoreTokenBase):
diff --git a/app/database/statistics.py b/app/database/statistics.py
new file mode 100644
index 0000000..cac2971
--- /dev/null
+++ b/app/database/statistics.py
@@ -0,0 +1,95 @@
+from typing import TYPE_CHECKING
+
+from app.models.score import GameMode
+
+from sqlmodel import (
+ BigInteger,
+ Column,
+ Field,
+ ForeignKey,
+ Relationship,
+ SQLModel,
+)
+
+if TYPE_CHECKING:
+ from .lazer_user import User
+
+
+class UserStatisticsBase(SQLModel):
+ mode: GameMode
+ count_100: int = Field(default=0, sa_column=Column(BigInteger))
+ count_300: int = Field(default=0, sa_column=Column(BigInteger))
+ count_50: int = Field(default=0, sa_column=Column(BigInteger))
+ count_miss: int = Field(default=0, sa_column=Column(BigInteger))
+
+ global_rank: int | None = Field(default=None)
+ country_rank: int | None = Field(default=None)
+
+ pp: float = Field(default=0.0)
+ ranked_score: int = Field(default=0)
+ hit_accuracy: float = Field(default=0.00)
+ total_score: int = Field(default=0, sa_column=Column(BigInteger))
+ total_hits: int = Field(default=0, sa_column=Column(BigInteger))
+ maximum_combo: int = Field(default=0)
+
+ play_count: int = Field(default=0)
+ play_time: int = Field(default=0, sa_column=Column(BigInteger))
+ replays_watched_by_others: int = Field(default=0)
+ is_ranked: bool = Field(default=True)
+
+
+class UserStatistics(UserStatisticsBase, table=True):
+ __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
+ id: int | None = Field(default=None, primary_key=True)
+ user_id: int = Field(
+ default=None,
+ sa_column=Column(
+ BigInteger,
+ ForeignKey("lazer_users.id"),
+ index=True,
+ ),
+ )
+ grade_ss: int = Field(default=0)
+ grade_ssh: int = Field(default=0)
+ grade_s: int = Field(default=0)
+ grade_sh: int = Field(default=0)
+ grade_a: int = Field(default=0)
+
+ level_current: int = Field(default=1)
+ level_progress: int = Field(default=0)
+
+ user: "User" = Relationship(back_populates="statistics") # type: ignore[valid-type]
+
+
+class UserStatisticsResp(UserStatisticsBase):
+ grade_counts: dict[str, int] = Field(
+ default_factory=lambda: {
+ "ss": 0,
+ "ssh": 0,
+ "s": 0,
+ "sh": 0,
+ "a": 0,
+ }
+ )
+ level: dict[str, int] = Field(
+ default_factory=lambda: {
+ "current": 1,
+ "progress": 0,
+ }
+ )
+
+ @classmethod
+ def from_db(cls, obj: UserStatistics) -> "UserStatisticsResp":
+ s = cls.model_validate(obj)
+ s.grade_counts = {
+ "ss": obj.grade_ss,
+ "ssh": obj.grade_ssh,
+ "s": obj.grade_s,
+ "sh": obj.grade_sh,
+ "a": obj.grade_a,
+ }
+ s.level = {
+ "current": obj.level_current,
+ "progress": obj.level_progress,
+ }
+ return s
diff --git a/app/database/team.py b/app/database/team.py
index 360e805..562b0c8 100644
--- a/app/database/team.py
+++ b/app/database/team.py
@@ -1,14 +1,16 @@
from datetime import datetime
from typing import TYPE_CHECKING
+from app.models.model import UTCBaseModel
+
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
- from .user import User
+ from .lazer_user import User
-class Team(SQLModel, table=True):
+class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -22,15 +24,19 @@ class Team(SQLModel, table=True):
members: list["TeamMember"] = Relationship(back_populates="team")
-class TeamMember(SQLModel, table=True):
+class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
+ user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
- user: "User" = Relationship(back_populates="team_membership")
- team: "Team" = Relationship(back_populates="members")
+ user: "User" = Relationship(
+ back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
+ )
+ team: "Team" = Relationship(
+ back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
+ )
diff --git a/app/database/user.py b/app/database/user.py
deleted file mode 100644
index a188497..0000000
--- a/app/database/user.py
+++ /dev/null
@@ -1,527 +0,0 @@
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Optional
-
-from .legacy import LegacyUserStatistics
-from .team import TeamMember
-
-from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
-from sqlalchemy.dialects.mysql import VARCHAR
-from sqlalchemy.orm import joinedload, selectinload
-from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
-
-
-class User(SQLModel, table=True):
- __tablename__ = "users" # pyright: ignore[reportAssignmentType]
-
- # 主键
- id: int = Field(
- default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
- )
-
- # 基本信息(匹配 migrations_old 中的结构)
- name: str = Field(max_length=32, unique=True, index=True) # 用户名
- safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名
- email: str = Field(max_length=254, unique=True, index=True)
- priv: int = Field(default=1) # 权限
- pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码
- country: str = Field(default="CN", max_length=2) # 国家代码
-
- # 状态和时间
- silence_end: int = Field(default=0)
- donor_end: int = Field(default=0)
- creation_time: int = Field(default=0) # Unix 时间戳
- latest_activity: int = Field(default=0) # Unix 时间戳
-
- # 游戏相关
- preferred_mode: int = Field(default=0) # 偏好游戏模式
- play_style: int = Field(default=0) # 游戏风格
-
- # 扩展信息
- clan_id: int = Field(default=0)
- clan_priv: int = Field(default=0)
- custom_badge_name: str | None = Field(default=None, max_length=16)
- custom_badge_icon: str | None = Field(default=None, max_length=64)
- userpage_content: str | None = Field(default=None, max_length=2048)
- api_key: str | None = Field(default=None, max_length=36, unique=True)
-
- # 虚拟字段用于兼容性
- @property
- def username(self):
- return self.name
-
- @property
- def country_code(self):
- return self.country
-
- @property
- def join_date(self):
- creation_time = getattr(self, "creation_time", 0)
- return (
- datetime.fromtimestamp(creation_time)
- if creation_time > 0
- else datetime.utcnow()
- )
-
- @property
- def last_visit(self):
- latest_activity = getattr(self, "latest_activity", 0)
- return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
-
- @property
- def is_supporter(self):
- return self.lazer_profile.is_supporter if self.lazer_profile else False
-
- # 关联关系
- lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
- lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
- lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user")
- lazer_achievements: list["LazerUserAchievement"] = Relationship(
- back_populates="user"
- )
- lazer_profile_sections: list["LazerUserProfileSections"] = Relationship(
- back_populates="user"
- )
- statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
- team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
- daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
- back_populates="user"
- )
- rank_history: list["RankHistory"] = Relationship(back_populates="user")
- avatar: Optional["UserAvatar"] = Relationship(back_populates="user")
- active_banners: list["LazerUserBanners"] = Relationship(back_populates="user")
- lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user")
- lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship(
- back_populates="user"
- )
- lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship(
- back_populates="user"
- )
- lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship(
- back_populates="user"
- )
-
- @classmethod
- def all_select_option(cls):
- return (
- joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
- joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
- joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
- joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
- selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
- joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
- selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
- selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
- selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
- )
-
- @classmethod
- def all_select_clause(cls):
- return select(cls).options(*cls.all_select_option())
-
-
-# ============================================
-# Lazer API 专用表模型
-# ============================================
-
-
-class LazerUserProfile(SQLModel, table=True):
- __tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
-
- user_id: int = Field(
- default=None,
- sa_column=Column(
- BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
- ),
- )
-
- # 基本状态字段
- is_active: bool = Field(default=True)
- is_bot: bool = Field(default=False)
- is_deleted: bool = Field(default=False)
- is_online: bool = Field(default=True)
- is_supporter: bool = Field(default=False)
- is_restricted: bool = Field(default=False)
- session_verified: bool = Field(default=False)
- has_supported: bool = Field(default=False)
- pm_friends_only: bool = Field(default=False)
-
- # 基本资料字段
- default_group: str = Field(default="default", max_length=50)
- last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime))
- join_date: datetime | None = Field(default=None, sa_column=Column(DateTime))
- profile_colour: str | None = Field(default=None, max_length=7)
- profile_hue: int | None = Field(default=None)
-
- # 社交媒体和个人资料字段
- avatar_url: str | None = Field(default=None, max_length=500)
- cover_url: str | None = Field(default=None, max_length=500)
- discord: str | None = Field(default=None, max_length=100)
- twitter: str | None = Field(default=None, max_length=100)
- website: str | None = Field(default=None, max_length=500)
- title: str | None = Field(default=None, max_length=100)
- title_url: str | None = Field(default=None, max_length=500)
- interests: str | None = Field(default=None, sa_column=Column(Text))
- location: str | None = Field(default=None, max_length=100)
-
- occupation: str | None = Field(default=None) # 职业字段,默认为 None
-
- # 游戏相关字段
- playmode: str = Field(default="osu", max_length=10)
- support_level: int = Field(default=0)
- max_blocks: int = Field(default=100)
- max_friends: int = Field(default=500)
- post_count: int = Field(default=0)
-
- # 页面内容
- page_html: str | None = Field(default=None, sa_column=Column(Text))
- page_raw: str | None = Field(default=None, sa_column=Column(Text))
-
- profile_order: str = Field(
- default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu"
- )
-
- # 关联关系
- user: "User" = Relationship(back_populates="lazer_profile")
-
-
-class LazerUserProfileSections(SQLModel, table=True):
- __tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- section_name: str = Field(sa_column=Column(VARCHAR(50)))
- display_order: int | None = Field(default=None)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="lazer_profile_sections")
-
-
-class LazerUserCountry(SQLModel, table=True):
- __tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
-
- user_id: int = Field(
- default=None,
- sa_column=Column(
- BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
- ),
- )
- code: str = Field(max_length=2)
- name: str = Field(max_length=100)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
-
-class LazerUserKudosu(SQLModel, table=True):
- __tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
-
- user_id: int = Field(
- default=None,
- sa_column=Column(
- BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
- ),
- )
- available: int = Field(default=0)
- total: int = Field(default=0)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
-
-class LazerUserCounts(SQLModel, table=True):
- __tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
-
- user_id: int = Field(
- default=None,
- sa_column=Column(
- BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
- ),
- )
-
- # 统计计数字段
- beatmap_playcounts_count: int = Field(default=0)
- comments_count: int = Field(default=0)
- favourite_beatmapset_count: int = Field(default=0)
- follower_count: int = Field(default=0)
- graveyard_beatmapset_count: int = Field(default=0)
- guest_beatmapset_count: int = Field(default=0)
- loved_beatmapset_count: int = Field(default=0)
- mapping_follower_count: int = Field(default=0)
- nominated_beatmapset_count: int = Field(default=0)
- pending_beatmapset_count: int = Field(default=0)
- ranked_beatmapset_count: int = Field(default=0)
- ranked_and_approved_beatmapset_count: int = Field(default=0)
- unranked_beatmapset_count: int = Field(default=0)
- scores_best_count: int = Field(default=0)
- scores_first_count: int = Field(default=0)
- scores_pinned_count: int = Field(default=0)
- scores_recent_count: int = Field(default=0)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- # 关联关系
- user: "User" = Relationship(back_populates="lazer_counts")
-
-
-class LazerUserStatistics(SQLModel, table=True):
- __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
-
- user_id: int = Field(
- default=None,
- sa_column=Column(
- BigInteger,
- ForeignKey("users.id"),
- primary_key=True,
- ),
- )
- mode: str = Field(default="osu", max_length=10, primary_key=True)
-
- # 基本命中统计
- count_100: int = Field(default=0)
- count_300: int = Field(default=0)
- count_50: int = Field(default=0)
- count_miss: int = Field(default=0)
-
- # 等级信息
- level_current: int = Field(default=1)
- level_progress: int = Field(default=0)
-
- # 排名信息
- global_rank: int | None = Field(default=None)
- global_rank_exp: int | None = Field(default=None)
- country_rank: int | None = Field(default=None)
-
- # PP 和分数
- pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
- pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
- ranked_score: int = Field(default=0, sa_column=Column(BigInteger))
- hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2)))
- total_score: int = Field(default=0, sa_column=Column(BigInteger))
- total_hits: int = Field(default=0, sa_column=Column(BigInteger))
- maximum_combo: int = Field(default=0)
-
- # 游戏统计
- play_count: int = Field(default=0)
- play_time: int = Field(default=0) # 秒
- replays_watched_by_others: int = Field(default=0)
- is_ranked: bool = Field(default=False)
-
- # 成绩等级计数
- grade_ss: int = Field(default=0)
- grade_ssh: int = Field(default=0)
- grade_s: int = Field(default=0)
- grade_sh: int = Field(default=0)
- grade_a: int = Field(default=0)
-
- # 最高排名记录
- rank_highest: int | None = Field(default=None)
- rank_highest_updated_at: datetime | None = Field(
- default=None, sa_column=Column(DateTime)
- )
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- # 关联关系
- user: "User" = Relationship(back_populates="lazer_statistics")
-
-
-class LazerUserBanners(SQLModel, table=True):
- __tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- tournament_id: int
- image_url: str = Field(sa_column=Column(VARCHAR(500)))
- is_active: bool | None = Field(default=None)
-
- # 修正user关系的back_populates值
- user: "User" = Relationship(back_populates="active_banners")
-
-
-class LazerUserAchievement(SQLModel, table=True):
- __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- achievement_id: int
- achieved_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="lazer_achievements")
-
-
-class LazerUserBadge(SQLModel, table=True):
- __tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- badge_id: int
- awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
- description: str | None = Field(default=None, sa_column=Column(Text))
- image_url: str | None = Field(default=None, max_length=500)
- url: str | None = Field(default=None, max_length=500)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="lazer_badges")
-
-
-class LazerUserMonthlyPlaycounts(SQLModel, table=True):
- __tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- start_date: datetime = Field(sa_column=Column(Date))
- play_count: int = Field(default=0)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="lazer_monthly_playcounts")
-
-
-class LazerUserPreviousUsername(SQLModel, table=True):
- __tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- username: str = Field(max_length=32)
- changed_at: datetime = Field(sa_column=Column(DateTime))
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="lazer_previous_usernames")
-
-
-class LazerUserReplaysWatched(SQLModel, table=True):
- __tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- start_date: datetime = Field(sa_column=Column(Date))
- count: int = Field(default=0)
-
- created_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="lazer_replays_watched")
-
-
-# 类型转换用的 UserAchievement(不是 SQLAlchemy 模型)
-@dataclass
-class UserAchievement:
- achieved_at: datetime
- achievement_id: int
-
-
-class DailyChallengeStats(SQLModel, table=True):
- __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(
- sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
- )
-
- daily_streak_best: int = Field(default=0)
- daily_streak_current: int = Field(default=0)
- last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
- last_weekly_streak: datetime | None = Field(
- default=None, sa_column=Column(DateTime)
- )
- playcount: int = Field(default=0)
- top_10p_placements: int = Field(default=0)
- top_50p_placements: int = Field(default=0)
- weekly_streak_best: int = Field(default=0)
- weekly_streak_current: int = Field(default=0)
-
- user: "User" = Relationship(back_populates="daily_challenge_stats")
-
-
-class RankHistory(SQLModel, table=True):
- __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- mode: str = Field(max_length=10)
- rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
- date_recorded: datetime = Field(
- default_factory=datetime.utcnow, sa_column=Column(DateTime)
- )
-
- user: "User" = Relationship(back_populates="rank_history")
-
-
-class UserAvatar(SQLModel, table=True):
- __tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
-
- id: int | None = Field(default=None, primary_key=True, index=True)
- user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
- filename: str = Field(max_length=255)
- original_filename: str = Field(max_length=255)
- file_size: int
- mime_type: str = Field(max_length=100)
- is_active: bool = Field(default=True)
- created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
- updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
- r2_original_url: str | None = Field(default=None, max_length=500)
- r2_game_url: str | None = Field(default=None, max_length=500)
-
- user: "User" = Relationship(back_populates="avatar")
diff --git a/app/database/user_account_history.py b/app/database/user_account_history.py
new file mode 100644
index 0000000..217c8eb
--- /dev/null
+++ b/app/database/user_account_history.py
@@ -0,0 +1,45 @@
+from datetime import UTC, datetime
+from enum import Enum
+
+from app.models.model import UTCBaseModel
+
+from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, SQLModel
+
+
+class UserAccountHistoryType(str, Enum):
+ NOTE = "note"
+ RESTRICTION = "restriction"
+ SLIENCE = "silence"
+ TOURNAMENT_BAN = "tournament_ban"
+
+
+class UserAccountHistoryBase(SQLModel, UTCBaseModel):
+ description: str | None = None
+ length: int
+ permanent: bool = False
+ timestamp: datetime = Field(default=datetime.now(UTC))
+ type: UserAccountHistoryType
+
+
+class UserAccountHistory(UserAccountHistoryBase, table=True):
+ __tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
+
+ id: int | None = Field(
+ sa_column=Column(
+ Integer,
+ autoincrement=True,
+ index=True,
+ primary_key=True,
+ )
+ )
+ user_id: int = Field(
+ sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
+ )
+
+
+class UserAccountHistoryResp(UserAccountHistoryBase):
+ id: int | None = None
+
+ @classmethod
+ def from_db(cls, db_model: UserAccountHistory) -> "UserAccountHistoryResp":
+ return cls.model_validate(db_model)
diff --git a/app/dependencies/database.py b/app/dependencies/database.py
index fe09139..77b15c3 100644
--- a/app/dependencies/database.py
+++ b/app/dependencies/database.py
@@ -5,15 +5,11 @@ import json
from app.config import settings
from pydantic import BaseModel
+import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
-try:
- import redis
-except ImportError:
- redis = None
-
def json_serializer(value):
if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接
-if redis:
- redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
-else:
- redis_client = None
+redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
# 数据库依赖
diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py
index d3c216a..806eb87 100644
--- a/app/dependencies/fetcher.py
+++ b/app/dependencies/fetcher.py
@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None
-def get_fetcher() -> Fetcher:
+async def get_fetcher() -> Fetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL,
)
redis = get_redis()
- if redis:
- access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
- if access_token:
- fetcher.access_token = str(access_token)
- refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
- if refresh_token:
- fetcher.refresh_token = str(refresh_token)
- if not fetcher.access_token or not fetcher.refresh_token:
- logger.opt(colors=True).info(
- f"Login to initialize fetcher: {fetcher.authorize_url}"
- )
+ access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
+ if access_token:
+ fetcher.access_token = str(access_token)
+ refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
+ if refresh_token:
+ fetcher.refresh_token = str(refresh_token)
+ if not fetcher.access_token or not fetcher.refresh_token:
+ logger.opt(colors=True).info(
+ f"Login to initialize fetcher: {fetcher.authorize_url}"
+ )
return fetcher
diff --git a/app/dependencies/user.py b/app/dependencies/user.py
index 0c8f8bc..5537f4f 100644
--- a/app/dependencies/user.py
+++ b/app/dependencies/user.py
@@ -1,14 +1,13 @@
from __future__ import annotations
from app.auth import get_token_by_access_token
-from app.database import (
- User as DBUser,
-)
+from app.database import User
from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
@@ -17,7 +16,7 @@ security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
-) -> DBUser:
+) -> User:
"""获取当前认证用户"""
token = credentials.credentials
@@ -27,13 +26,9 @@ async def get_current_user(
return user
-async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
+async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
- user = (
- await db.exec(
- DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
- )
- ).first()
+ user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user
diff --git a/app/signalr/exception.py b/app/exception.py
similarity index 100%
rename from app/signalr/exception.py
rename to app/exception.py
diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py
index 08e3508..2717a35 100644
--- a/app/fetcher/_base.py
+++ b/app/fetcher/_base.py
@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
- if redis:
- redis.set(
- f"fetcher:access_token:{self.client_id}",
- self.access_token,
- ex=token_data["expires_in"],
- )
- redis.set(
- f"fetcher:refresh_token:{self.client_id}",
- self.refresh_token,
- )
+ await redis.set(
+ f"fetcher:access_token:{self.client_id}",
+ self.access_token,
+ ex=token_data["expires_in"],
+ )
+ await redis.set(
+ f"fetcher:refresh_token:{self.client_id}",
+ self.refresh_token,
+ )
async def refresh_access_token(self) -> None:
async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
- if redis:
- redis.set(
- f"fetcher:access_token:{self.client_id}",
- self.access_token,
- ex=token_data["expires_in"],
- )
- redis.set(
- f"fetcher:refresh_token:{self.client_id}",
- self.refresh_token,
- )
+ await redis.set(
+ f"fetcher:access_token:{self.client_id}",
+ self.access_token,
+ ex=token_data["expires_in"],
+ )
+ await redis.set(
+ f"fetcher:refresh_token:{self.client_id}",
+ self.refresh_token,
+ )
diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py
index 08b8dfc..6e18435 100644
--- a/app/fetcher/osu_dot_direct.py
+++ b/app/fetcher/osu_dot_direct.py
@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient
from loguru import logger
-import redis
+import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher):
@@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher):
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
) -> str:
- if redis.exists(f"beatmap:{beatmap_id}:raw"):
- return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
+ if await redis.exists(f"beatmap:{beatmap_id}:raw"):
+ return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
- redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
+ await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw
diff --git a/app/models/beatmap.py b/app/models/beatmap.py
index 4f12e13..fae18ba 100644
--- a/app/models/beatmap.py
+++ b/app/models/beatmap.py
@@ -42,11 +42,12 @@ class Language(IntEnum):
KOREAN = 6
FRENCH = 7
GERMAN = 8
- ITALIAN = 9
- SPANISH = 10
- RUSSIAN = 11
- POLISH = 12
- OTHER = 13
+ SWEDISH = 9
+ ITALIAN = 10
+ SPANISH = 11
+ RUSSIAN = 12
+ POLISH = 13
+ OTHER = 14
class BeatmapAttributes(BaseModel):
diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py
index 8ae3e65..684ab54 100644
--- a/app/models/metadata_hub.py
+++ b/app/models/metadata_hub.py
@@ -1,114 +1,85 @@
from __future__ import annotations
from enum import IntEnum
-from typing import Any, Literal
+from typing import ClassVar, Literal
-from app.models.signalr import UserState
+from app.models.signalr import SignalRUnionMessage, UserState
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel
-class _UserActivity(BaseModel):
- model_config = ConfigDict(serialize_by_alias=True)
- type: Literal[
- "ChoosingBeatmap",
- "InSoloGame",
- "WatchingReplay",
- "SpectatingUser",
- "SearchingForLobby",
- "InLobby",
- "InMultiplayerGame",
- "SpectatingMultiplayerGame",
- "InPlaylistGame",
- "EditingBeatmap",
- "ModdingBeatmap",
- "TestingBeatmap",
- "InDailyChallengeLobby",
- "PlayingDailyChallenge",
- ] = Field(alias="$dtype")
- value: Any | None = Field(alias="$value")
+class _UserActivity(SignalRUnionMessage): ...
class ChoosingBeatmap(_UserActivity):
- type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
-
-
-class InGameValue(BaseModel):
- beatmap_id: int = Field(alias="BeatmapID")
- beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
- ruleset_id: int = Field(alias="RulesetID")
- ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
+ union_type: ClassVar[Literal[11]] = 11
class _InGame(_UserActivity):
- value: InGameValue = Field(alias="$value")
+ beatmap_id: int
+ beatmap_display_title: str
+ ruleset_id: int
+ ruleset_playing_verb: str
class InSoloGame(_InGame):
- type: Literal["InSoloGame"] = Field(alias="$dtype")
+ union_type: ClassVar[Literal[12]] = 12
class InMultiplayerGame(_InGame):
- type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
+ union_type: ClassVar[Literal[23]] = 23
class SpectatingMultiplayerGame(_InGame):
- type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
+ union_type: ClassVar[Literal[24]] = 24
class InPlaylistGame(_InGame):
- type: Literal["InPlaylistGame"] = Field(alias="$dtype")
+ union_type: ClassVar[Literal[31]] = 31
-class EditingBeatmapValue(BaseModel):
- beatmap_id: int = Field(alias="BeatmapID")
- beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
+class PlayingDailyChallenge(_InGame):
+ union_type: ClassVar[Literal[52]] = 52
class EditingBeatmap(_UserActivity):
- type: Literal["EditingBeatmap"] = Field(alias="$dtype")
- value: EditingBeatmapValue = Field(alias="$value")
+ union_type: ClassVar[Literal[41]] = 41
+ beatmap_id: int
+ beatmap_display_title: str
-class TestingBeatmap(_UserActivity):
- type: Literal["TestingBeatmap"] = Field(alias="$dtype")
+class TestingBeatmap(EditingBeatmap):
+ union_type: ClassVar[Literal[43]] = 43
-class ModdingBeatmap(_UserActivity):
- type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
-
-
-class WatchingReplayValue(BaseModel):
- score_id: int = Field(alias="ScoreID")
- player_name: str = Field(alias="PlayerName")
- beatmap_id: int = Field(alias="BeatmapID")
- beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
+class ModdingBeatmap(EditingBeatmap):
+ union_type: ClassVar[Literal[42]] = 42
class WatchingReplay(_UserActivity):
- type: Literal["WatchingReplay"] = Field(alias="$dtype")
- value: int | None = Field(alias="$value") # Replay ID
+ union_type: ClassVar[Literal[13]] = 13
+ score_id: int
+ player_name: str
+ beatmap_id: int
+ beatmap_display_title: str
class SpectatingUser(WatchingReplay):
- type: Literal["SpectatingUser"] = Field(alias="$dtype")
+ union_type: ClassVar[Literal[14]] = 14
class SearchingForLobby(_UserActivity):
- type: Literal["SearchingForLobby"] = Field(alias="$dtype")
-
-
-class InLobbyValue(BaseModel):
- room_id: int = Field(alias="RoomID")
- room_name: str = Field(alias="RoomName")
+ union_type: ClassVar[Literal[21]] = 21
class InLobby(_UserActivity):
- type: Literal["InLobby"] = "InLobby"
+ union_type: ClassVar[Literal[22]] = 22
+ room_id: int
+ room_name: str
class InDailyChallengeLobby(_UserActivity):
- type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
+ union_type: ClassVar[Literal[51]] = 51
UserActivity = (
@@ -128,23 +99,25 @@ UserActivity = (
)
-class MetadataClientState(UserState):
- user_activity: UserActivity | None = None
- status: OnlineStatus | None = None
+class UserPresence(BaseModel):
+ activity: UserActivity | None = None
- def to_dict(self) -> dict[str, Any] | None:
- if self.status is None or self.status == OnlineStatus.OFFLINE:
- return None
- dumped = self.model_dump(by_alias=True, exclude_none=True)
- return {
- "Activity": dumped.get("user_activity"),
- "Status": dumped.get("status"),
- }
+ status: OnlineStatus | None = None
@property
def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE
+ @property
+ def for_push(self) -> "UserPresence | None":
+ return UserPresence(
+ activity=self.activity,
+ status=self.status,
+ )
+
+
+class MetadataClientState(UserPresence, UserState): ...
+
class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身
diff --git a/app/models/model.py b/app/models/model.py
new file mode 100644
index 0000000..5ba8093
--- /dev/null
+++ b/app/models/model.py
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from datetime import UTC, datetime
+
+from pydantic import BaseModel, field_serializer
+
+
+class UTCBaseModel(BaseModel):
+ @field_serializer("*", when_used="json")
+ def serialize_datetime(self, v, _info):
+ if isinstance(v, datetime):
+ if v.tzinfo is None:
+ v = v.replace(tzinfo=UTC)
+ return v.astimezone(UTC).isoformat()
+ return v
+
+
+Cursor = dict[str, int]
+
+
+class RespWithCursor(BaseModel):
+ cursor: Cursor | None = None
diff --git a/app/models/mods.py b/app/models/mods.py
index abcd2cd..299a05f 100644
--- a/app/models/mods.py
+++ b/app/models/mods.py
@@ -8,7 +8,7 @@ from app.path import STATIC_DIR
class APIMod(TypedDict):
acronym: str
- settings: NotRequired[dict[str, bool | float | str]]
+ settings: NotRequired[dict[str, bool | float | str | int]]
# https://github.com/ppy/osu-api/wiki#mods
diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py
new file mode 100644
index 0000000..09d8900
--- /dev/null
+++ b/app/models/multiplayer_hub.py
@@ -0,0 +1,925 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+import asyncio
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from datetime import UTC, datetime, timedelta
+from enum import IntEnum
+from typing import (
+ TYPE_CHECKING,
+ Annotated,
+ Any,
+ ClassVar,
+ Literal,
+ TypedDict,
+ cast,
+ override,
+)
+
+from app.database.beatmap import Beatmap
+from app.dependencies.database import engine
+from app.exception import InvokeException
+
+from .mods import APIMod
+from .room import (
+ DownloadState,
+ MatchType,
+ MultiplayerRoomState,
+ MultiplayerUserState,
+ QueueMode,
+ RoomCategory,
+ RoomStatus,
+)
+from .signalr import (
+ SignalRMeta,
+ SignalRUnionMessage,
+ UserState,
+)
+
+from pydantic import BaseModel, Field
+from sqlalchemy import update
+from sqlmodel import col
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+if TYPE_CHECKING:
+ from app.signalr.hub import MultiplayerHub
+
+HOST_LIMIT = 50
+PER_USER_LIMIT = 3
+
+
+class MultiplayerClientState(UserState):
+ room_id: int = 0
+
+
+class MultiplayerRoomSettings(BaseModel):
+ name: str = "Unnamed Room"
+ playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
+ password: str = ""
+ match_type: MatchType = MatchType.HEAD_TO_HEAD
+ queue_mode: QueueMode = QueueMode.HOST_ONLY
+ auto_start_duration: timedelta = timedelta(seconds=0)
+ auto_skip: bool = False
+
+ @property
+ def auto_start_enabled(self) -> bool:
+ return self.auto_start_duration != timedelta(seconds=0)
+
+
+class BeatmapAvailability(BaseModel):
+ state: DownloadState = DownloadState.UNKNOWN
+ download_progress: float | None = None
+
+
+class _MatchUserState(SignalRUnionMessage): ...
+
+
+class TeamVersusUserState(_MatchUserState):
+ team_id: int
+
+ union_type: ClassVar[Literal[0]] = 0
+
+
+MatchUserState = TeamVersusUserState
+
+
+class _MatchRoomState(SignalRUnionMessage): ...
+
+
+class MultiplayerTeam(BaseModel):
+ id: int
+ name: str
+
+
+class TeamVersusRoomState(_MatchRoomState):
+ teams: list[MultiplayerTeam] = Field(
+ default_factory=lambda: [
+ MultiplayerTeam(id=0, name="Team Red"),
+ MultiplayerTeam(id=1, name="Team Blue"),
+ ]
+ )
+
+ union_type: ClassVar[Literal[0]] = 0
+
+
+MatchRoomState = TeamVersusRoomState
+
+
+class PlaylistItem(BaseModel):
+ id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
+ owner_id: int
+ beatmap_id: int
+ beatmap_checksum: str
+ ruleset_id: int
+ required_mods: list[APIMod] = Field(default_factory=list)
+ allowed_mods: list[APIMod] = Field(default_factory=list)
+ expired: bool
+ playlist_order: int
+ played_at: datetime | None = None
+ star_rating: float
+ freestyle: bool
+
+ def _get_api_mods(self):
+ from app.models.mods import API_MODS, init_mods
+
+ if not API_MODS:
+ init_mods()
+ return API_MODS
+
+ def _validate_mod_for_ruleset(
+ self, mod: APIMod, ruleset_key: int, context: str = "mod"
+ ) -> None:
+ from typing import Literal, cast
+
+ API_MODS = self._get_api_mods()
+ typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
+
+ # Check if mod is valid for ruleset
+ if (
+ typed_ruleset_key not in API_MODS
+ or mod["acronym"] not in API_MODS[typed_ruleset_key]
+ ):
+ raise InvokeException(
+ f"{context} {mod['acronym']} is invalid for this ruleset"
+ )
+
+ mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
+
+ # Check if mod is unplayable in multiplayer
+ if mod_settings.get("UserPlayable", True) is False:
+ raise InvokeException(
+ f"{context} {mod['acronym']} is not playable by users"
+ )
+
+ if mod_settings.get("ValidForMultiplayer", True) is False:
+ raise InvokeException(
+ f"{context} {mod['acronym']} is not valid for multiplayer"
+ )
+
+ def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
+ from typing import Literal, cast
+
+ API_MODS = self._get_api_mods()
+ typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
+
+ for i, mod1 in enumerate(mods):
+ mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
+ if mod1_settings:
+ incompatible = set(mod1_settings.get("IncompatibleMods", []))
+ for mod2 in mods[i + 1 :]:
+ if mod2["acronym"] in incompatible:
+ raise InvokeException(
+ f"Mods {mod1['acronym']} and "
+ f"{mod2['acronym']} are incompatible"
+ )
+
+ def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
+ from typing import Literal, cast
+
+ API_MODS = self._get_api_mods()
+ typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
+ allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
+
+ for req_mod in self.required_mods:
+ req_acronym = req_mod["acronym"]
+ req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
+ if req_settings:
+ incompatible = set(req_settings.get("IncompatibleMods", []))
+ conflicting_allowed = allowed_acronyms & incompatible
+ if conflicting_allowed:
+ conflict_list = ", ".join(conflicting_allowed)
+ raise InvokeException(
+ f"Required mod {req_acronym} conflicts with "
+ f"allowed mods: {conflict_list}"
+ )
+
+ def validate_playlist_item_mods(self) -> None:
+ ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
+
+ # Validate required mods
+ for mod in self.required_mods:
+ self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
+
+ # Validate allowed mods
+ for mod in self.allowed_mods:
+ self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
+
+ # Check internal compatibility of required mods
+ self._check_mod_compatibility(self.required_mods, ruleset_key)
+
+ # Check compatibility between required and allowed mods
+ self._check_required_allowed_compatibility(ruleset_key)
+
+ def validate_user_mods(
+ self,
+ user: "MultiplayerRoomUser",
+ proposed_mods: list[APIMod],
+ ) -> tuple[bool, list[APIMod]]:
+ """
+ Validates user mods against playlist item rules and returns valid mods.
+ Returns (is_valid, valid_mods).
+ """
+ from typing import Literal, cast
+
+ API_MODS = self._get_api_mods()
+
+ ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
+ ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
+
+ valid_mods = []
+ all_proposed_valid = True
+
+ # Check if mods are valid for the ruleset
+ for mod in proposed_mods:
+ if (
+ ruleset_key not in API_MODS
+ or mod["acronym"] not in API_MODS[ruleset_key]
+ ):
+ all_proposed_valid = False
+ continue
+ valid_mods.append(mod)
+
+ # Check mod compatibility within user mods
+ incompatible_mods = set()
+ final_valid_mods = []
+ for mod in valid_mods:
+ if mod["acronym"] in incompatible_mods:
+ all_proposed_valid = False
+ continue
+ setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
+ if setting_mods:
+ incompatible_mods.update(setting_mods["IncompatibleMods"])
+ final_valid_mods.append(mod)
+
+ # If not freestyle, check against allowed mods
+ if not self.freestyle:
+ allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
+ filtered_valid_mods = []
+ for mod in final_valid_mods:
+ if mod["acronym"] not in allowed_acronyms:
+ all_proposed_valid = False
+ else:
+ filtered_valid_mods.append(mod)
+ final_valid_mods = filtered_valid_mods
+
+ # Check compatibility with required mods
+ required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
+ all_mod_acronyms = {
+ mod["acronym"] for mod in final_valid_mods
+ } | required_mod_acronyms
+
+ # Check for incompatibility between required and user mods
+ filtered_valid_mods = []
+ for mod in final_valid_mods:
+ mod_acronym = mod["acronym"]
+ is_compatible = True
+
+ for other_acronym in all_mod_acronyms:
+ if other_acronym == mod_acronym:
+ continue
+ setting_mods = API_MODS[ruleset_key].get(mod_acronym)
+ if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
+ is_compatible = False
+ all_proposed_valid = False
+ break
+
+ if is_compatible:
+ filtered_valid_mods.append(mod)
+
+ return all_proposed_valid, filtered_valid_mods
+
+ def clone(self) -> "PlaylistItem":
+ copy = self.model_copy()
+ copy.required_mods = list(self.required_mods)
+ copy.allowed_mods = list(self.allowed_mods)
+ copy.expired = False
+ copy.played_at = None
+ return copy
+
+
+class _MultiplayerCountdown(SignalRUnionMessage):
+ id: int = 0
+ time_remaining: timedelta
+ is_exclusive: Annotated[
+ bool, Field(default=True), SignalRMeta(member_ignore=True)
+ ] = True
+
+
+class MatchStartCountdown(_MultiplayerCountdown):
+ union_type: ClassVar[Literal[0]] = 0
+
+
+class ForceGameplayStartCountdown(_MultiplayerCountdown):
+ union_type: ClassVar[Literal[1]] = 1
+
+
+class ServerShuttingDownCountdown(_MultiplayerCountdown):
+ union_type: ClassVar[Literal[2]] = 2
+
+
+MultiplayerCountdown = (
+ MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
+)
+
+
+class MultiplayerRoomUser(BaseModel):
+ user_id: int
+ state: MultiplayerUserState = MultiplayerUserState.IDLE
+ availability: BeatmapAvailability = BeatmapAvailability(
+ state=DownloadState.UNKNOWN, download_progress=None
+ )
+ mods: list[APIMod] = Field(default_factory=list)
+ match_state: MatchUserState | None = None
+ ruleset_id: int | None = None # freestyle
+ beatmap_id: int | None = None # freestyle
+
+
+class MultiplayerRoom(BaseModel):
+ room_id: int
+ state: MultiplayerRoomState
+ settings: MultiplayerRoomSettings
+ users: list[MultiplayerRoomUser] = Field(default_factory=list)
+ host: MultiplayerRoomUser | None = None
+ match_state: MatchRoomState | None = None
+ playlist: list[PlaylistItem] = Field(default_factory=list)
+ active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
+ channel_id: int
+
+ @classmethod
+ def from_db(cls, room) -> "MultiplayerRoom":
+ """
+ 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
+ """
+
+ # 用户列表
+ users = [MultiplayerRoomUser(user_id=room.host_id)]
+ host_user = MultiplayerRoomUser(user_id=room.host_id)
+ # playlist 转换
+ playlist = []
+ if hasattr(room, "playlist"):
+ for item in room.playlist:
+ playlist.append(
+ PlaylistItem(
+ id=item.id,
+ owner_id=item.owner_id,
+ beatmap_id=item.beatmap_id,
+ beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
+ ruleset_id=item.ruleset_id,
+ required_mods=item.required_mods,
+ allowed_mods=item.allowed_mods,
+ expired=item.expired,
+ playlist_order=item.playlist_order,
+ played_at=item.played_at,
+ star_rating=item.beatmap.difficulty_rating
+ if item.beatmap is not None
+ else 0.0,
+ freestyle=item.freestyle,
+ )
+ )
+
+ return cls(
+ room_id=room.id,
+ state=getattr(room, "state", MultiplayerRoomState.OPEN),
+ settings=MultiplayerRoomSettings(
+ name=room.name,
+ playlist_item_id=playlist[0].id if playlist else 0,
+ password=getattr(room, "password", ""),
+ match_type=room.type,
+ queue_mode=room.queue_mode,
+ auto_start_duration=timedelta(seconds=room.auto_start_duration),
+ auto_skip=room.auto_skip,
+ ),
+ users=users,
+ host=host_user,
+ match_state=None,
+ playlist=playlist,
+ active_countdowns=[],
+ channel_id=getattr(room, "channel_id", 0),
+ )
+
+
+class MultiplayerQueue:
+ def __init__(self, room: "ServerMultiplayerRoom"):
+ self.server_room = room
+ self.current_index = 0
+
+ @property
+ def hub(self) -> "MultiplayerHub":
+ return self.server_room.hub
+
+ @property
+ def upcoming_items(self):
+ return sorted(
+ (item for item in self.room.playlist if not item.expired),
+ key=lambda i: i.playlist_order,
+ )
+
+ @property
+ def room(self):
+ return self.server_room.room
+
+ async def update_order(self):
+ from app.database import Playlist
+
+ match self.room.settings.queue_mode:
+ case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
+ ordered_active_items = []
+
+ is_first_set = True
+ first_set_order_by_user_id = {}
+
+ active_items = [item for item in self.room.playlist if not item.expired]
+ active_items.sort(key=lambda x: x.id)
+
+ user_item_groups = {}
+ for item in active_items:
+ if item.owner_id not in user_item_groups:
+ user_item_groups[item.owner_id] = []
+ user_item_groups[item.owner_id].append(item)
+
+ max_items = max(
+ (len(items) for items in user_item_groups.values()), default=0
+ )
+
+ for i in range(max_items):
+ current_set = []
+ for user_id, items in user_item_groups.items():
+ if i < len(items):
+ current_set.append(items[i])
+
+ if is_first_set:
+ current_set.sort(
+ key=lambda item: (item.playlist_order, item.id)
+ )
+ ordered_active_items.extend(current_set)
+ first_set_order_by_user_id = {
+ item.owner_id: idx
+ for idx, item in enumerate(ordered_active_items)
+ }
+ else:
+ current_set.sort(
+ key=lambda item: first_set_order_by_user_id.get(
+ item.owner_id, 0
+ )
+ )
+ ordered_active_items.extend(current_set)
+
+ is_first_set = False
+
+ for idx, item in enumerate(ordered_active_items):
+ item.playlist_order = idx
+ case _:
+ ordered_active_items = sorted(
+ (item for item in self.room.playlist if not item.expired),
+ key=lambda x: x.id,
+ )
+ async with AsyncSession(engine) as session:
+ for idx, item in enumerate(ordered_active_items):
+ if item.playlist_order == idx:
+ continue
+ item.playlist_order = idx
+ await Playlist.update(item, self.room.room_id, session)
+ await self.hub.playlist_changed(
+ self.server_room, item, beatmap_changed=False
+ )
+
+ async def update_current_item(self):
+ upcoming_items = self.upcoming_items
+ next_item = (
+ upcoming_items[0]
+ if upcoming_items
+ else max(
+ self.room.playlist,
+ key=lambda i: i.played_at or datetime.min,
+ )
+ )
+ self.current_index = self.room.playlist.index(next_item)
+ last_id = self.room.settings.playlist_item_id
+ self.room.settings.playlist_item_id = next_item.id
+ if last_id != next_item.id:
+ await self.hub.setting_changed(self.server_room, True)
+
+ async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
+ from app.database import Playlist
+
+ is_host = self.room.host and self.room.host.user_id == user.user_id
+ if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
+ raise InvokeException("You are not the host")
+
+ limit = HOST_LIMIT if is_host else PER_USER_LIMIT
+ if (
+ len([True for u in self.room.playlist if u.owner_id == user.user_id])
+ >= limit
+ ):
+ raise InvokeException(f"You can only have {limit} items in the queue")
+
+ if item.freestyle and len(item.allowed_mods) > 0:
+ raise InvokeException("Freestyle items cannot have allowed mods")
+
+ async with AsyncSession(engine) as session:
+ async with session:
+ beatmap = await session.get(Beatmap, item.beatmap_id)
+ if beatmap is None:
+ raise InvokeException("Beatmap not found")
+ if item.beatmap_checksum != beatmap.checksum:
+ raise InvokeException("Checksum mismatch")
+
+ item.validate_playlist_item_mods()
+ item.owner_id = user.user_id
+ item.star_rating = float(
+ beatmap.difficulty_rating
+ ) # FIXME: beatmap use decimal
+ await Playlist.add_to_db(item, self.room.room_id, session)
+ self.room.playlist.append(item)
+ await self.hub.playlist_added(self.server_room, item)
+ await self.update_order()
+ await self.update_current_item()
+
+ async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
+ from app.database import Playlist
+
+ if item.freestyle and len(item.allowed_mods) > 0:
+ raise InvokeException("Freestyle items cannot have allowed mods")
+
+ async with AsyncSession(engine) as session:
+ async with session:
+ beatmap = await session.get(Beatmap, item.beatmap_id)
+ if beatmap is None:
+ raise InvokeException("Beatmap not found")
+ if item.beatmap_checksum != beatmap.checksum:
+ raise InvokeException("Checksum mismatch")
+
+ existing_item = next(
+ (i for i in self.room.playlist if i.id == item.id), None
+ )
+ if existing_item is None:
+ raise InvokeException(
+ "Attempted to change an item that doesn't exist"
+ )
+
+ if existing_item.owner_id != user.user_id and self.room.host != user:
+ raise InvokeException(
+ "Attempted to change an item which is not owned by the user"
+ )
+
+ if existing_item.expired:
+ raise InvokeException(
+ "Attempted to change an item which has already been played"
+ )
+
+ item.validate_playlist_item_mods()
+ item.owner_id = user.user_id
+ item.star_rating = float(beatmap.difficulty_rating)
+ item.playlist_order = existing_item.playlist_order
+
+ await Playlist.update(item, self.room.room_id, session)
+
+ # Update item in playlist
+ for idx, playlist_item in enumerate(self.room.playlist):
+ if playlist_item.id == item.id:
+ self.room.playlist[idx] = item
+ break
+
+ await self.hub.playlist_changed(
+ self.server_room,
+ item,
+ beatmap_changed=item.beatmap_checksum
+ != existing_item.beatmap_checksum,
+ )
+
+ async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
+ from app.database import Playlist
+
+ item = next(
+ (i for i in self.room.playlist if i.id == playlist_item_id),
+ None,
+ )
+
+ if item is None:
+ raise InvokeException("Item does not exist in the room")
+
+ # Check if it's the only item and current item
+ if item == self.current_item:
+ upcoming_items = [i for i in self.room.playlist if not i.expired]
+ if len(upcoming_items) == 1:
+ raise InvokeException("The only item in the room cannot be removed")
+
+ if item.owner_id != user.user_id and self.room.host != user:
+ raise InvokeException(
+ "Attempted to remove an item which is not owned by the user"
+ )
+
+ if item.expired:
+ raise InvokeException(
+ "Attempted to remove an item which has already been played"
+ )
+
+ async with AsyncSession(engine) as session:
+ await Playlist.delete_item(item.id, self.room.room_id, session)
+
+ self.room.playlist.remove(item)
+ self.current_index = self.room.playlist.index(self.upcoming_items[0])
+
+ await self.update_order()
+ await self.update_current_item()
+ await self.hub.playlist_removed(self.server_room, item.id)
+
+ async def finish_current_item(self):
+ from app.database import Playlist
+
+ async with AsyncSession(engine) as session:
+ played_at = datetime.now(UTC)
+ await session.execute(
+ update(Playlist)
+ .where(
+ col(Playlist.id) == self.current_item.id,
+ col(Playlist.room_id) == self.room.room_id,
+ )
+ .values(expired=True, played_at=played_at)
+ )
+ self.room.playlist[self.current_index].expired = True
+ self.room.playlist[self.current_index].played_at = played_at
+ await self.hub.playlist_changed(self.server_room, self.current_item, True)
+ await self.update_order()
+ if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
+ playitem.expired for playitem in self.room.playlist
+ ):
+ assert self.room.host
+ await self.add_item(self.current_item.clone(), self.room.host)
+
+ async def update_queue_mode(self):
+ if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
+ playitem.expired for playitem in self.room.playlist
+ ):
+ assert self.room.host
+ await self.add_item(self.current_item.clone(), self.room.host)
+ await self.update_order()
+ await self.update_current_item()
+
+ @property
+ def current_item(self):
+ return self.room.playlist[self.current_index]
+
+
+@dataclass
+class CountdownInfo:
+ countdown: MultiplayerCountdown
+ duration: timedelta
+ task: asyncio.Task | None = None
+
+ def __init__(self, countdown: MultiplayerCountdown):
+ self.countdown = countdown
+ self.duration = (
+ countdown.time_remaining
+ if countdown.time_remaining > timedelta(seconds=0)
+ else timedelta(seconds=0)
+ )
+
+
+class _MatchRequest(SignalRUnionMessage): ...
+
+
+class ChangeTeamRequest(_MatchRequest):
+ union_type: ClassVar[Literal[0]] = 0
+ team_id: int
+
+
+class StartMatchCountdownRequest(_MatchRequest):
+ union_type: ClassVar[Literal[1]] = 1
+ duration: timedelta
+
+
+class StopCountdownRequest(_MatchRequest):
+ union_type: ClassVar[Literal[2]] = 2
+ id: int
+
+
+MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
+
+
+class MatchTypeHandler(ABC):
+ def __init__(self, room: "ServerMultiplayerRoom"):
+ self.room = room
+ self.hub = room.hub
+
+ @abstractmethod
+ async def handle_join(self, user: MultiplayerRoomUser): ...
+
+ @abstractmethod
+ async def handle_request(
+ self, user: MultiplayerRoomUser, request: MatchRequest
+ ): ...
+
+ @abstractmethod
+ async def handle_leave(self, user: MultiplayerRoomUser): ...
+
+ @abstractmethod
+ def get_details(self) -> MatchStartedEventDetail: ...
+
+
+class HeadToHeadHandler(MatchTypeHandler):
+ @override
+ async def handle_join(self, user: MultiplayerRoomUser):
+ if user.match_state is not None:
+ user.match_state = None
+ await self.hub.change_user_match_state(self.room, user)
+
+ @override
+ async def handle_request(
+ self, user: MultiplayerRoomUser, request: MatchRequest
+ ): ...
+
+ @override
+ async def handle_leave(self, user: MultiplayerRoomUser): ...
+
+ @override
+ def get_details(self) -> MatchStartedEventDetail:
+ detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
+ return detail
+
+
+class TeamVersusHandler(MatchTypeHandler):
+ @override
+ def __init__(self, room: "ServerMultiplayerRoom"):
+ super().__init__(room)
+ self.state = TeamVersusRoomState()
+ room.room.match_state = self.state
+ task = asyncio.create_task(self.hub.change_room_match_state(self.room))
+ self.hub.tasks.add(task)
+ task.add_done_callback(self.hub.tasks.discard)
+
+ def _get_best_available_team(self) -> int:
+ for team in self.state.teams:
+ if all(
+ (
+ user.match_state is None
+ or not isinstance(user.match_state, TeamVersusUserState)
+ or user.match_state.team_id != team.id
+ )
+ for user in self.room.room.users
+ ):
+ return team.id
+
+ from collections import defaultdict
+
+ team_counts = defaultdict(int)
+ for user in self.room.room.users:
+ if user.match_state is not None and isinstance(
+ user.match_state, TeamVersusUserState
+ ):
+ team_counts[user.match_state.team_id] += 1
+
+ if team_counts:
+ min_count = min(team_counts.values())
+ for team_id, count in team_counts.items():
+ if count == min_count:
+ return team_id
+ return self.state.teams[0].id if self.state.teams else 0
+
+ @override
+ async def handle_join(self, user: MultiplayerRoomUser):
+ best_team_id = self._get_best_available_team()
+ user.match_state = TeamVersusUserState(team_id=best_team_id)
+ await self.hub.change_user_match_state(self.room, user)
+
+ @override
+ async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
+ if not isinstance(request, ChangeTeamRequest):
+ return
+
+ if request.team_id not in [team.id for team in self.state.teams]:
+ raise InvokeException("Invalid team ID")
+
+ user.match_state = TeamVersusUserState(team_id=request.team_id)
+ await self.hub.change_user_match_state(self.room, user)
+
+ @override
+ async def handle_leave(self, user: MultiplayerRoomUser): ...
+
+ @override
+ def get_details(self) -> MatchStartedEventDetail:
+ teams: dict[int, Literal["blue", "red"]] = {}
+ for user in self.room.room.users:
+ if user.match_state is not None and isinstance(
+ user.match_state, TeamVersusUserState
+ ):
+ teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
+ detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
+ return detail
+
+
+MATCH_TYPE_HANDLERS = {
+ MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
+ MatchType.TEAM_VERSUS: TeamVersusHandler,
+}
+
+
+@dataclass
+class ServerMultiplayerRoom:
+ room: MultiplayerRoom
+ category: RoomCategory
+ status: RoomStatus
+ start_at: datetime
+ hub: "MultiplayerHub"
+ match_type_handler: MatchTypeHandler
+ queue: MultiplayerQueue
+ _next_countdown_id: int
+ _countdown_id_lock: asyncio.Lock
+ _tracked_countdown: dict[int, CountdownInfo]
+
+ def __init__(
+ self,
+ room: MultiplayerRoom,
+ category: RoomCategory,
+ start_at: datetime,
+ hub: "MultiplayerHub",
+ ):
+ self.room = room
+ self.category = category
+ self.status = RoomStatus.IDLE
+ self.start_at = start_at
+ self.hub = hub
+ self.queue = MultiplayerQueue(self)
+ self._next_countdown_id = 0
+ self._countdown_id_lock = asyncio.Lock()
+ self._tracked_countdown = {}
+
+ async def set_handler(self):
+ self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
+ self
+ )
+ for i in self.room.users:
+ await self.match_type_handler.handle_join(i)
+
+ async def get_next_countdown_id(self) -> int:
+ async with self._countdown_id_lock:
+ self._next_countdown_id += 1
+ return self._next_countdown_id
+
+ async def start_countdown(
+ self,
+ countdown: MultiplayerCountdown,
+ on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
+ ):
+ async def _countdown_task(self: "ServerMultiplayerRoom"):
+ await asyncio.sleep(info.duration.total_seconds())
+ if on_complete is not None:
+ await on_complete(self)
+ await self.stop_countdown(countdown)
+
+ if countdown.is_exclusive:
+ await self.stop_all_countdowns()
+ countdown.id = await self.get_next_countdown_id()
+ info = CountdownInfo(countdown)
+ self.room.active_countdowns.append(info.countdown)
+ self._tracked_countdown[countdown.id] = info
+ await self.hub.send_match_event(
+ self, CountdownStartedEvent(countdown=info.countdown)
+ )
+ info.task = asyncio.create_task(_countdown_task(self))
+
+ async def stop_countdown(self, countdown: MultiplayerCountdown):
+ info = self._tracked_countdown.get(countdown.id)
+ if info is None:
+ return
+ del self._tracked_countdown[countdown.id]
+ self.room.active_countdowns.remove(countdown)
+ await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
+ if info.task is not None and not info.task.done():
+ info.task.cancel()
+
+ async def stop_all_countdowns(self):
+ for countdown in list(self._tracked_countdown.values()):
+ await self.stop_countdown(countdown.countdown)
+
+ self._tracked_countdown.clear()
+ self.room.active_countdowns.clear()
+
+
+class _MatchServerEvent(SignalRUnionMessage): ...
+
+
+class CountdownStartedEvent(_MatchServerEvent):
+ countdown: MultiplayerCountdown
+
+ union_type: ClassVar[Literal[0]] = 0
+
+
+class CountdownStoppedEvent(_MatchServerEvent):
+ id: int
+
+ union_type: ClassVar[Literal[1]] = 1
+
+
+MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
+
+
+class GameplayAbortReason(IntEnum):
+ LOAD_TOOK_TOO_LONG = 0
+ HOST_ABORTED = 1
+
+
+class MatchStartedEventDetail(TypedDict):
+ room_type: Literal["playlists", "head_to_head", "team_versus"]
+ team: dict[int, Literal["blue", "red"]] | None
diff --git a/app/models/oauth.py b/app/models/oauth.py
index 22fcf63..6665965 100644
--- a/app/models/oauth.py
+++ b/app/models/oauth.py
@@ -1,7 +1,6 @@
# OAuth 相关模型
from __future__ import annotations
-from typing import List
from pydantic import BaseModel
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
class RegistrationErrorResponse(BaseModel):
"""注册错误响应模型"""
+
form_error: dict
class UserRegistrationErrors(BaseModel):
"""用户注册错误模型"""
- username: List[str] = []
- user_email: List[str] = []
- password: List[str] = []
+
+ username: list[str] = []
+ user_email: list[str] = []
+ password: list[str] = []
class RegistrationRequestErrors(BaseModel):
"""注册请求错误模型"""
+
message: str | None = None
redirect: str | None = None
user: UserRegistrationErrors | None = None
diff --git a/app/models/room.py b/app/models/room.py
index 2d01a26..392562a 100644
--- a/app/models/room.py
+++ b/app/models/room.py
@@ -1,17 +1,8 @@
from __future__ import annotations
-from datetime import datetime, timedelta
from enum import Enum
-from app.database.beatmap import Beatmap, BeatmapResp
-from app.database.user import User as DBUser
-from app.fetcher import Fetcher
-from app.models.mods import APIMod
-from app.models.user import User
-from app.utils import convert_db_user_to_api_user
-
-from pydantic import BaseModel, Field
-from sqlmodel.ext.asyncio.session import AsyncSession
+from pydantic import BaseModel
class RoomCategory(str, Enum):
@@ -62,53 +53,24 @@ class MultiplayerUserState(str, Enum):
RESULTS = "results"
SPECTATING = "spectating"
+ @property
+ def is_playing(self) -> bool:
+ return self in {
+ self.WAITING_FOR_LOAD,
+ self.PLAYING,
+ self.READY_FOR_GAMEPLAY,
+ self.LOADED,
+ }
+
class DownloadState(str, Enum):
- UNKOWN = "unkown"
+ UNKNOWN = "unknown"
NOT_DOWNLOADED = "not_downloaded"
DOWNLOADING = "downloading"
IMPORTING = "importing"
LOCALLY_AVAILABLE = "locally_available"
-class PlaylistItem(BaseModel):
- id: int
- owner_id: int
- ruleset_id: int
- expired: bool
- playlist_order: int | None
- played_at: datetime | None
- allowed_mods: list[APIMod] = []
- required_mods: list[APIMod] = []
- beatmap_id: int
- beatmap: BeatmapResp | None
- freestyle: bool
-
- class Config:
- exclude_none = True
-
- @classmethod
- async def from_mpListItem(
- cls, item: MultiPlayerListItem, db: AsyncSession, fetcher: Fetcher
- ):
- s = cls.model_validate(item.model_dump())
- s.id = item.id
- s.owner_id = item.OwnerID
- s.ruleset_id = item.RulesetID
- s.expired = item.Expired
- s.playlist_order = item.PlaylistOrder
- s.played_at = item.PlayedAt
- s.required_mods = item.RequierdMods
- s.allowed_mods = item.AllowedMods
- s.freestyle = item.Freestyle
- cur_beatmap = await Beatmap.get_or_fetch(
- db, fetcher=fetcher, bid=item.BeatmapID
- )
- s.beatmap = BeatmapResp.from_db(cur_beatmap)
- s.beatmap_id = item.BeatmapID
- return s
-
-
class RoomPlaylistItemStats(BaseModel):
count_active: int
count_total: int
@@ -120,269 +82,7 @@ class RoomDifficultyRange(BaseModel):
max: float
-class ItemAttemptsCount(BaseModel):
- id: int
- attempts: int
- passed: bool
-
-
-class PlaylistAggregateScore(BaseModel):
- playlist_item_attempts: list[ItemAttemptsCount]
-
-
-class MultiplayerRoomSettings(BaseModel):
- Name: str = "Unnamed Room"
- PlaylistItemId: int
- Password: str = ""
- MatchType: MatchType
- QueueMode: QueueMode
- AutoStartDuration: timedelta
- AutoSkip: bool
-
- @classmethod
- def from_apiRoom(cls, room: Room):
- s = cls.model_validate(room.model_dump())
- s.Name = room.name
- s.Password = room.password if room.password is not None else ""
- s.MatchType = room.type
- s.QueueMode = room.queue_mode
- s.AutoStartDuration = timedelta(seconds=room.auto_start_duration)
- s.AutoSkip = room.auto_skip
- return s
-
-
-class BeatmapAvailability(BaseModel):
- State: DownloadState
- DownloadProgress: float | None
-
-
-class MatchUserState(BaseModel):
- class Config:
- extra = "allow"
-
-
-class TeamVersusState(MatchUserState):
- TeamId: int
-
-
-MatchUserStateType = TeamVersusState | MatchUserState
-
-
-class MultiplayerRoomUser(BaseModel):
- UserID: int
- State: MultiplayerUserState = MultiplayerUserState.IDLE
- BeatmapAvailability: BeatmapAvailability
- Mods: list[APIMod] = []
- MatchUserState: MatchUserStateType | None
- RulesetId: int | None
- BeatmapId: int | None
- User: User | None
-
- @classmethod
- async def from_id(cls, id: int, db: AsyncSession):
- actualUser = (
- await db.exec(
- DBUser.all_select_clause().where(
- DBUser.id == id,
- )
- )
- ).first()
- user = (
- await convert_db_user_to_api_user(actualUser)
- if actualUser is not None
- else None
- )
- return MultiplayerRoomUser(
- UserID=id,
- MatchUserState=None,
- BeatmapAvailability=BeatmapAvailability(
- State=DownloadState.UNKOWN, DownloadProgress=None
- ),
- RulesetId=None,
- BeatmapId=None,
- User=user,
- )
-
-
-class MatchRoomState(BaseModel):
- class Config:
- extra = "allow"
-
-
-class MultiPlayerTeam(BaseModel):
- id: int = 0
- name: str = ""
-
-
-class TeamVersusRoomState(BaseModel):
- teams: list[MultiPlayerTeam] = []
-
- class Config:
- pass
-
- @classmethod
- def create_default(cls):
- return cls(
- teams=[
- MultiPlayerTeam(id=0, name="Team Red"),
- MultiPlayerTeam(id=1, name="Team Blue"),
- ]
- )
-
-
-MatchRoomStateType = TeamVersusRoomState | MatchRoomState
-
-
-class MultiPlayerListItem(BaseModel):
- id: int
- OwnerID: int
- BeatmapID: int
- BeatmapChecksum: str = ""
- RulesetID: int
- RequierdMods: list[APIMod]
- AllowedMods: list[APIMod]
- Expired: bool
- PlaylistOrder: int | None
- PlayedAt: datetime | None
- StarRating: float
- Freestyle: bool
-
- @classmethod
- async def from_apiItem(cls, item: PlaylistItem, db: AsyncSession, fetcher: Fetcher):
- s = cls.model_validate(item.model_dump())
- s.id = item.id
- s.OwnerID = item.owner_id
- if item.beatmap is None: # 从客户端接受的一定没有这字段
- cur_beatmap = await Beatmap.get_or_fetch(
- db, fetcher=fetcher, bid=item.beatmap_id
- )
- s.BeatmapID = cur_beatmap.id if cur_beatmap.id is not None else 0
- s.BeatmapChecksum = cur_beatmap.checksum
- s.StarRating = cur_beatmap.difficulty_rating
- s.RulesetID = item.ruleset_id
- s.RequierdMods = item.required_mods
- s.AllowedMods = item.allowed_mods
- s.Expired = item.expired
- s.PlaylistOrder = item.playlist_order if item.playlist_order is not None else 0
- s.PlayedAt = item.played_at
- s.Freestyle = item.freestyle
- return s
-
-
-class MultiplayerCountdown(BaseModel):
- id: int = 0
- time_remaining: timedelta = timedelta(seconds=0)
- is_exclusive: bool = True
-
- class Config:
- extra = "allow"
-
-
-class MatchStartCountdown(MultiplayerCountdown):
- pass
-
-
-class ForceGameplayStartCountdown(MultiplayerCountdown):
- pass
-
-
-class ServerShuttingCountdown(MultiplayerCountdown):
- pass
-
-
-MultiplayerCountdownType = (
- MatchStartCountdown
- | ForceGameplayStartCountdown
- | ServerShuttingCountdown
- | MultiplayerCountdown
-)
-
-
class PlaylistStatus(BaseModel):
count_active: int
count_total: int
ruleset_ids: list[int]
-
-
-class MultiplayerRoom(BaseModel):
- RoomId: int
- State: MultiplayerRoomState
- Settings: MultiplayerRoomSettings = MultiplayerRoomSettings(
- PlaylistItemId=0,
- MatchType=MatchType.HEAD_TO_HEAD,
- QueueMode=QueueMode.HOST_ONLY,
- AutoStartDuration=timedelta(0),
- AutoSkip=False,
- )
- Users: list[MultiplayerRoomUser]
- Host: MultiplayerRoomUser
- MatchState: MatchRoomState | None
- Playlist: list[MultiPlayerListItem]
- ActivecCountDowns: list[MultiplayerCountdownType]
- ChannelID: int
-
- @classmethod
- def CanAddPlayistItem(cls, user: MultiplayerRoomUser) -> bool:
- return user == cls.Host or cls.Settings.QueueMode != QueueMode.HOST_ONLY
-
- @classmethod
- async def from_apiRoom(cls, room: Room, db: AsyncSession, fetcher: Fetcher):
- s = cls.model_validate(room.model_dump())
- s.RoomId = room.room_id if room.room_id is not None else 0
- s.ChannelID = room.channel_id
- s.Settings = MultiplayerRoomSettings.from_apiRoom(room)
- s.Host = await MultiplayerRoomUser.from_id(room.host.id if room.host else 0, db)
- s.Playlist = [
- await MultiPlayerListItem.from_apiItem(item, db, fetcher)
- for item in room.playlist
- ]
- return s
-
-
-class Room(BaseModel):
- room_id: int
- name: str
- password: str | None
- has_password: bool = Field(exclude=True)
- host: User | None
- category: RoomCategory
- duration: int | None
- starts_at: datetime | None
- ends_at: datetime | None
- max_particapants: int | None = Field(exclude=True)
- particapant_count: int
- recent_particapants: list[User]
- type: MatchType
- max_attempts: int | None
- playlist: list[PlaylistItem]
- playlist_item_status: list[RoomPlaylistItemStats]
- difficulity_range: RoomDifficultyRange
- queue_mode: QueueMode
- auto_skip: bool
- auto_start_duration: int
- current_user_score: PlaylistAggregateScore | None
- current_playlist_item: PlaylistItem | None
- channel_id: int
- status: RoomStatus
- availability: RoomAvailability = Field(exclude=True)
-
- class Config:
- exclude_none = True
-
- @classmethod
- async def from_mpRoom(
- cls, room: MultiplayerRoom, db: AsyncSession, fetcher: Fetcher
- ):
- s = cls.model_validate(room.model_dump())
- s.room_id = room.RoomId
- s.name = room.Settings.Name
- s.password = room.Settings.Password
- s.type = room.Settings.MatchType
- s.queue_mode = room.Settings.QueueMode
- s.auto_skip = room.Settings.AutoSkip
- s.host = room.Host.User
- s.playlist = [
- await PlaylistItem.from_mpListItem(item, db, fetcher)
- for item in room.Playlist
- ]
- return s
diff --git a/app/models/score.py b/app/models/score.py
index b613ae2..cef6b28 100644
--- a/app/models/score.py
+++ b/app/models/score.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from enum import Enum, IntEnum
+from enum import Enum
from typing import Literal, TypedDict
from .mods import API_MODS, APIMod, init_mods
@@ -93,52 +93,14 @@ class HitResult(str, Enum):
)
-class HitResultInt(IntEnum):
- PERFECT = 0
- GREAT = 1
- GOOD = 2
- OK = 3
- MEH = 4
- MISS = 5
-
- LARGE_TICK_HIT = 6
- SMALL_TICK_HIT = 7
- SLIDER_TAIL_HIT = 8
-
- LARGE_BONUS = 9
- SMALL_BONUS = 10
-
- LARGE_TICK_MISS = 11
- SMALL_TICK_MISS = 12
-
- IGNORE_HIT = 13
- IGNORE_MISS = 14
-
- NONE = 15
- COMBO_BREAK = 16
-
- LEGACY_COMBO_INCREASE = 99
-
- def is_hit(self) -> bool:
- return self not in (
- HitResultInt.NONE,
- HitResultInt.IGNORE_MISS,
- HitResultInt.COMBO_BREAK,
- HitResultInt.LARGE_TICK_MISS,
- HitResultInt.SMALL_TICK_MISS,
- HitResultInt.MISS,
- )
-
-
class LeaderboardType(Enum):
GLOBAL = "global"
- FRIENDS = "friends"
+ FRIENDS = "friend"
COUNTRY = "country"
TEAM = "team"
ScoreStatistics = dict[HitResult, int]
-ScoreStatisticsInt = dict[HitResultInt, int]
class SoloScoreSubmissionInfo(BaseModel):
@@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel):
class LegacyReplaySoloScoreInfo(TypedDict):
online_id: int
mods: list[APIMod]
- statistics: ScoreStatisticsInt
- maximum_statistics: ScoreStatisticsInt
+ statistics: ScoreStatistics
+ maximum_statistics: ScoreStatistics
client_version: str
rank: Rank
user_id: int
diff --git a/app/models/signalr.py b/app/models/signalr.py
index 37b2741..ffbaf6b 100644
--- a/app/models/signalr.py
+++ b/app/models/signalr.py
@@ -1,54 +1,23 @@
from __future__ import annotations
-import datetime
-from typing import Any, get_origin
+from dataclasses import dataclass
+from typing import ClassVar
from pydantic import (
BaseModel,
- ConfigDict,
Field,
- TypeAdapter,
- model_serializer,
- model_validator,
)
-def serialize_to_list(value: BaseModel) -> list[Any]:
- data = []
- for field, info in value.__class__.model_fields.items():
- v = getattr(value, field)
- anno = get_origin(info.annotation)
- if anno and issubclass(anno, BaseModel):
- data.append(serialize_to_list(v))
- elif anno and issubclass(anno, list):
- data.append(
- TypeAdapter(
- info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
- ).dump_python(v)
- )
- elif isinstance(v, datetime.datetime):
- data.append([v, 0])
- else:
- data.append(v)
- return data
+@dataclass
+class SignalRMeta:
+ member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
+ json_ignore: bool = False # implement of JsonIgnore (json) attribute
+ use_abbr: bool = True
-class MessagePackArrayModel(BaseModel):
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
- @model_validator(mode="before")
- @classmethod
- def unpack(cls, v: Any) -> Any:
- if isinstance(v, list):
- fields = list(cls.model_fields.keys())
- if len(v) != len(fields):
- raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
- return dict(zip(fields, v))
- return v
-
- @model_serializer
- def serialize(self) -> list[Any]:
- return serialize_to_list(self)
+class SignalRUnionMessage(BaseModel):
+ union_type: ClassVar[int]
class Transport(BaseModel):
diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py
index 994e083..9f35932 100644
--- a/app/models/spectator_hub.py
+++ b/app/models/spectator_hub.py
@@ -2,17 +2,17 @@ from __future__ import annotations
import datetime
from enum import IntEnum
-from typing import Any
+from typing import Annotated, Any
from app.models.beatmap import BeatmapRankStatus
+from app.models.mods import APIMod
from .score import (
- ScoreStatisticsInt,
+ ScoreStatistics,
)
-from .signalr import MessagePackArrayModel, UserState
+from .signalr import SignalRMeta, UserState
-from msgpack_lazer_api import APIMod
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, Field, field_validator
class SpectatedUserState(IntEnum):
@@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum):
Quit = 5
-class SpectatorState(MessagePackArrayModel):
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
+class SpectatorState(BaseModel):
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
- maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
+ maximum_statistics: ScoreStatistics = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
@@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel):
)
-class ScoreProcessorStatistics(MessagePackArrayModel):
- base_score: int
- maximum_base_score: int
+class ScoreProcessorStatistics(BaseModel):
+ base_score: float
+ maximum_base_score: float
accuracy_judgement_count: int
combo_portion: float
- bouns_portion: float
+ bonus_portion: float
-class FrameHeader(MessagePackArrayModel):
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
+class FrameHeader(BaseModel):
total_score: int
- acc: float
+ accuracy: float
combo: int
max_combo: int
- statistics: ScoreStatisticsInt = Field(default_factory=dict)
+ statistics: ScoreStatistics = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel):
# SMOKE = 16
-class LegacyReplayFrame(MessagePackArrayModel):
+class LegacyReplayFrame(BaseModel):
time: float # from ReplayFrame,the parent of LegacyReplayFrame
- x: float | None = None
- y: float | None = None
+ mouse_x: float | None = None
+ mouse_y: float | None = None
button_state: int
+ header: Annotated[
+ FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
+ ]
-class FrameDataBundle(MessagePackArrayModel):
+
+class FrameDataBundle(BaseModel):
header: FrameHeader
frames: list[LegacyReplayFrame]
@@ -106,18 +106,16 @@ class APIUser(BaseModel):
class ScoreInfo(BaseModel):
- model_config = ConfigDict(arbitrary_types_allowed=True)
-
mods: list[APIMod]
user: APIUser
ruleset: int
- maximum_statistics: ScoreStatisticsInt
+ maximum_statistics: ScoreStatistics
id: int | None = None
total_score: int | None = None
- acc: float | None = None
+ accuracy: float | None = None
max_combo: int | None = None
combo: int | None = None
- statistics: ScoreStatisticsInt = Field(default_factory=dict)
+ statistics: ScoreStatistics = Field(default_factory=dict)
class StoreScore(BaseModel):
diff --git a/app/models/user.py b/app/models/user.py
index 4eb0911..3052eef 100644
--- a/app/models/user.py
+++ b/app/models/user.py
@@ -2,15 +2,11 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
-from typing import TYPE_CHECKING
-from .score import GameMode
+from .model import UTCBaseModel
from pydantic import BaseModel
-if TYPE_CHECKING:
- from app.database import LazerUserAchievement, Team
-
class PlayStyle(str, Enum):
MOUSE = "mouse"
@@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel):
count: int
-class UserAchievement(BaseModel):
- achieved_at: datetime
- achievement_id: int
-
- # 添加数据库模型转换方法
- def to_db_model(self, user_id: int) -> "LazerUserAchievement":
- from app.database import (
- LazerUserAchievement,
- )
-
- return LazerUserAchievement(
- user_id=user_id,
- achievement_id=self.achievement_id,
- achieved_at=self.achieved_at,
- )
-
-
-class RankHighest(BaseModel):
+class RankHighest(UTCBaseModel):
rank: int
updated_at: datetime
@@ -104,115 +83,6 @@ class RankHistory(BaseModel):
data: list[int]
-class DailyChallengeStats(BaseModel):
- daily_streak_best: int = 0
- daily_streak_current: int = 0
- last_update: datetime | None = None
- last_weekly_streak: datetime | None = None
- playcount: int = 0
- top_10p_placements: int = 0
- top_50p_placements: int = 0
- user_id: int
- weekly_streak_best: int = 0
- weekly_streak_current: int = 0
-
-
class Page(BaseModel):
html: str = ""
raw: str = ""
-
-
-class User(BaseModel):
- # 基本信息
- id: int
- username: str
- avatar_url: str
- country_code: str
- default_group: str = "default"
- is_active: bool = True
- is_bot: bool = False
- is_deleted: bool = False
- is_online: bool = True
- is_supporter: bool = False
- is_restricted: bool = False
- last_visit: datetime | None = None
- pm_friends_only: bool = False
- profile_colour: str | None = None
-
- # 个人资料
- cover_url: str | None = None
- discord: str | None = None
- has_supported: bool = False
- interests: str | None = None
- join_date: datetime
- location: str | None = None
- max_blocks: int = 100
- max_friends: int = 500
- occupation: str | None = None
- playmode: GameMode = GameMode.OSU
- playstyle: list[PlayStyle] = []
- post_count: int = 0
- profile_hue: int | None = None
- profile_order: list[str] = [
- "me",
- "recent_activity",
- "top_ranks",
- "medals",
- "historical",
- "beatmaps",
- "kudosu",
- ]
- title: str | None = None
- title_url: str | None = None
- twitter: str | None = None
- website: str | None = None
- session_verified: bool = False
- support_level: int = 0
-
- # 关联对象
- country: Country
- cover: Cover
- kudosu: Kudosu
- statistics: Statistics
- statistics_rulesets: dict[str, Statistics]
-
- # 计数信息
- beatmap_playcounts_count: int = 0
- comments_count: int = 0
- favourite_beatmapset_count: int = 0
- follower_count: int = 0
- graveyard_beatmapset_count: int = 0
- guest_beatmapset_count: int = 0
- loved_beatmapset_count: int = 0
- mapping_follower_count: int = 0
- nominated_beatmapset_count: int = 0
- pending_beatmapset_count: int = 0
- ranked_beatmapset_count: int = 0
- ranked_and_approved_beatmapset_count: int = 0
- unranked_beatmapset_count: int = 0
- scores_best_count: int = 0
- scores_first_count: int = 0
- scores_pinned_count: int = 0
- scores_recent_count: int = 0
-
- # 历史数据
- account_history: list[dict] = []
- active_tournament_banner: dict | None = None
- active_tournament_banners: list[dict] = []
- badges: list[dict] = []
- current_season_stats: dict | None = None
- daily_challenge_user_stats: DailyChallengeStats | None = None
- groups: list[dict] = []
- monthly_playcounts: list[MonthlyPlaycount] = []
- page: Page = Page()
- previous_usernames: list[str] = []
- rank_highest: RankHighest | None = None
- rank_history: RankHistory | None = None
- rankHistory: RankHistory | None = None # 兼容性别名
- replays_watched_counts: list[dict] = []
- team: "Team | None" = None
- user_achievements: list[UserAchievement] = []
-
-
-class APIUser(BaseModel):
- id: int
diff --git a/app/router/__init__.py b/app/router/__init__.py
index 1e87343..22f6c70 100644
--- a/app/router/__init__.py
+++ b/app/router/__init__.py
@@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
beatmapset,
me,
relationship,
+ room,
score,
user,
)
@@ -14,4 +15,9 @@ from .api_router import router as api_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
-__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
+__all__ = [
+ "api_router",
+ "auth_router",
+ "fetcher_router",
+ "signalr_router",
+]
diff --git a/app/router/auth.py b/app/router/auth.py
index 0f41b32..7a2a14d 100644
--- a/app/router/auth.py
+++ b/app/router/auth.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from datetime import timedelta
+from datetime import UTC, datetime, timedelta
import re
from app.auth import (
@@ -12,17 +12,21 @@ from app.auth import (
store_token,
)
from app.config import settings
-from app.database import User as DBUser
+from app.database import DailyChallengeStats, User
+from app.database.statistics import UserStatistics
from app.dependencies import get_db
+from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
+from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
+from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -110,12 +114,12 @@ async def register_user(
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
- result = await db.exec(select(DBUser).where(DBUser.name == user_username))
+ result = await db.exec(select(User).where(User.username == user_username))
existing_user = result.first()
if existing_user:
username_errors.append("Username is already taken")
- result = await db.exec(select(DBUser).where(DBUser.email == user_email))
+ result = await db.exec(select(User).where(User.email == user_email))
existing_email = result.first()
if existing_email:
email_errors.append("Email is already taken")
@@ -135,119 +139,41 @@ async def register_user(
try:
# 创建新用户
- from datetime import datetime
- import time
+ # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
+ result = await db.execute( # pyright: ignore[reportDeprecated]
+ text(
+ "SELECT AUTO_INCREMENT FROM information_schema.TABLES "
+ "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
+ )
+ )
+ next_id = result.one()[0]
+ if next_id <= 2:
+ await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
+ await db.commit()
- new_user = DBUser(
- name=user_username,
- safe_name=user_username.lower(), # 安全用户名(小写)
+ new_user = User(
+ username=user_username,
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
- country="CN", # 默认国家
- creation_time=int(time.time()),
- latest_activity=int(time.time()),
- preferred_mode=0, # 默认模式
- play_style=0, # 默认游戏风格
+ country_code="CN", # 默认国家
+ join_date=datetime.now(UTC),
+ last_visit=datetime.now(UTC),
)
-
db.add(new_user)
await db.commit()
await db.refresh(new_user)
-
- # 保存用户ID,因为会话可能会关闭
- user_id = new_user.id
-
- if user_id <= 2:
- await db.rollback()
- try:
- from sqlalchemy import text
-
- # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
- await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
- await db.commit()
-
- # 重新创建用户
- new_user = DBUser(
- name=user_username,
- safe_name=user_username.lower(),
- email=user_email,
- pw_bcrypt=get_password_hash(user_password),
- priv=1,
- country="CN",
- creation_time=int(time.time()),
- latest_activity=int(time.time()),
- preferred_mode=0,
- play_style=0,
- )
-
- db.add(new_user)
- await db.commit()
- await db.refresh(new_user)
- user_id = new_user.id
-
- # 最终检查ID是否有效
- if user_id <= 2:
- await db.rollback()
- errors = RegistrationRequestErrors(
- message=(
- "Failed to create account with valid ID. "
- "Please contact support."
- )
- )
- return JSONResponse(
- status_code=500, content={"form_error": errors.model_dump()}
- )
-
- except Exception as fix_error:
- await db.rollback()
- print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
- errors = RegistrationRequestErrors(
- message="Failed to create account with valid ID. Please try again."
- )
- return JSONResponse(
- status_code=500, content={"form_error": errors.model_dump()}
- )
-
- # 创建默认的 lazer_profile
- from app.database.user import LazerUserProfile
-
- lazer_profile = LazerUserProfile(
- user_id=user_id,
- is_active=True,
- is_bot=False,
- is_deleted=False,
- is_online=True,
- is_supporter=False,
- is_restricted=False,
- session_verified=False,
- has_supported=False,
- pm_friends_only=False,
- default_group="default",
- join_date=datetime.utcnow(),
- playmode="osu",
- support_level=0,
- max_blocks=50,
- max_friends=250,
- post_count=0,
- )
-
- db.add(lazer_profile)
+ assert new_user.id is not None, "New user ID should not be None"
+ for i in GameMode:
+ statistics = UserStatistics(mode=i, user_id=new_user.id)
+ db.add(statistics)
+ daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
+ db.add(daily_challenge_user_stats)
await db.commit()
-
- # 返回成功响应
- return JSONResponse(
- status_code=201,
- content={"message": "Account created successfully", "user_id": user_id},
- )
-
- except Exception as e:
+ except Exception:
await db.rollback()
# 打印详细错误信息用于调试
- print(f"Registration error: {e}")
- import traceback
-
- traceback.print_exc()
+ logger.exception(f"Registration error for user {user_username}")
# 返回通用错误
errors = RegistrationRequestErrors(
@@ -323,6 +249,7 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
+ assert user.id
await store_token(
db,
user.id,
diff --git a/app/router/beatmap.py b/app/router/beatmap.py
index 71d554f..7dfd0f9 100644
--- a/app/router/beatmap.py
+++ b/app/router/beatmap.py
@@ -5,12 +5,7 @@ import hashlib
import json
from app.calculator import calculate_beatmap_attribute
-from app.database import (
- Beatmap,
- BeatmapResp,
- User as DBUser,
-)
-from app.database.beatmapset import Beatmapset
+from app.database import Beatmap, BeatmapResp, User
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -27,9 +22,8 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
-from redis import Redis
+from redis.asyncio import Redis
import rosu_pp_py as rosu
-from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -39,7 +33,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -56,19 +50,19 @@ async def lookup_beatmap(
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
- return BeatmapResp.from_db(beatmap)
+ return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
- return BeatmapResp.from_db(beatmap)
+ return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -80,43 +74,28 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp)
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps(
- b_ids: list[int] = Query(alias="id", default_factory=list),
- current_user: DBUser = Depends(get_current_user),
+ b_ids: list[int] = Query(alias="ids[]", default_factory=list),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = (
await db.exec(
- select(Beatmap)
- .options(
- joinedload(
- Beatmap.beatmapset # pyright: ignore[reportArgumentType]
- ).selectinload(
- Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
- )
- )
- .order_by(col(Beatmap.last_updated).desc())
- .limit(50)
+ select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
else:
beatmaps = (
- await db.exec(
- select(Beatmap)
- .options(
- joinedload(
- Beatmap.beatmapset # pyright: ignore[reportArgumentType]
- ).selectinload(
- Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
- )
- )
- .where(col(Beatmap.id).in_(b_ids))
- .limit(50)
- )
+ await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
).all()
- return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
+ return BatchGetResp(
+ beatmaps=[
+ await BeatmapResp.from_db(bm, session=db, user=current_user)
+ for bm in beatmaps
+ ]
+ )
@router.post(
@@ -126,7 +105,7 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
beatmap: int,
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None),
@@ -153,8 +132,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
- if redis.exists(key):
- return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
+ if await redis.exists(key):
+ return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -164,7 +143,7 @@ async def get_beatmap_attributes(
)
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e))
- redis.set(key, attr.model_dump_json())
+ await redis.set(key, attr.model_dump_json())
return attr
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py
index db2dd77..f77c2ed 100644
--- a/app/router/beatmapset.py
+++ b/app/router/beatmapset.py
@@ -1,10 +1,8 @@
from __future__ import annotations
-from app.database import (
- Beatmapset,
- BeatmapsetResp,
- User as DBUser,
-)
+from typing import Literal
+
+from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.dependencies.database import get_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -12,27 +10,47 @@ from app.fetcher import Fetcher
from .api_router import router
-from fastapi import Depends, HTTPException
+from fastapi import Depends, Form, HTTPException, Query
+from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
-from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
+@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
+async def lookup_beatmapset(
+ beatmap_id: int = Query(),
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+ fetcher: Fetcher = Depends(get_fetcher),
+):
+ beatmapset_id = (
+ await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id))
+ ).first()
+ if not beatmapset_id:
+ try:
+ resp = await fetcher.get_beatmap(beatmap_id)
+ await Beatmap.from_resp(db, resp)
+ await db.refresh(current_user)
+ except HTTPStatusError:
+ raise HTTPException(status_code=404, detail="Beatmapset not found")
+ beatmapset = (
+ await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))
+ ).first()
+ if not beatmapset:
+ raise HTTPException(status_code=404, detail="Beatmapset not found")
+ resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user)
+ return resp
+
+
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
- beatmapset = (
- await db.exec(
- select(Beatmapset)
- .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
- .where(Beatmapset.id == sid)
- )
- ).first()
+ beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
@@ -40,5 +58,55 @@ async def get_beatmapset(
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
- resp = BeatmapsetResp.from_db(beatmapset)
+ resp = await BeatmapsetResp.from_db(
+ beatmapset, session=db, include=["recent_favourites"], user=current_user
+ )
return resp
+
+
+@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
+async def download_beatmapset(
+ beatmapset: int,
+ no_video: bool = Query(True, alias="noVideo"),
+ current_user: User = Depends(get_current_user),
+):
+ if current_user.country_code == "CN":
+ return RedirectResponse(
+ f"https://txy1.sayobot.cn/beatmaps/download/"
+ f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto"
+ )
+ else:
+ return RedirectResponse(
+ f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}"
+ )
+
+
+@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"])
+async def favourite_beatmapset(
+ beatmapset: int,
+ action: Literal["favourite", "unfavourite"] = Form(),
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ existing_favourite = (
+ await db.exec(
+ select(FavouriteBeatmapset).where(
+ FavouriteBeatmapset.user_id == current_user.id,
+ FavouriteBeatmapset.beatmapset_id == beatmapset,
+ )
+ )
+ ).first()
+
+ if action == "favourite" and existing_favourite:
+ raise HTTPException(status_code=400, detail="Already favourited")
+ elif action == "unfavourite" and not existing_favourite:
+ raise HTTPException(status_code=400, detail="Not favourited")
+
+ if action == "favourite":
+ favourite = FavouriteBeatmapset(
+ user_id=current_user.id, beatmapset_id=beatmapset
+ )
+ db.add(favourite)
+ else:
+ await db.delete(existing_favourite)
+ await db.commit()
diff --git a/app/router/me.py b/app/router/me.py
index 93dcbdc..b6d7d26 100644
--- a/app/router/me.py
+++ b/app/router/me.py
@@ -1,28 +1,27 @@
from __future__ import annotations
-from typing import Literal
-
-from app.database import (
- User as DBUser,
-)
+from app.database import User, UserResp
+from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
-from app.models.user import (
- User as ApiUser,
-)
-from app.utils import convert_db_user_to_api_user
+from app.dependencies.database import get_db
+from app.models.score import GameMode
from .api_router import router
from fastapi import Depends
+from sqlmodel.ext.asyncio.session import AsyncSession
-@router.get("/me/{ruleset}", response_model=ApiUser)
-@router.get("/me/", response_model=ApiUser)
+@router.get("/me/{ruleset}", response_model=UserResp)
+@router.get("/me/", response_model=UserResp)
async def get_user_info_default(
- ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
- current_user: DBUser = Depends(get_current_user),
+ ruleset: GameMode | None = None,
+ current_user: User = Depends(get_current_user),
+ session: AsyncSession = Depends(get_db),
):
- """获取当前用户信息(默认使用osu模式)"""
- # 默认使用osu模式
- api_user = await convert_db_user_to_api_user(current_user, ruleset)
- return api_user
+ return await UserResp.from_db(
+ current_user,
+ session,
+ ALL_INCLUDED,
+ ruleset,
+ )
diff --git a/app/router/relationship.py b/app/router/relationship.py
index 9ed5b0f..0832d09 100644
--- a/app/router/relationship.py
+++ b/app/router/relationship.py
@@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
-from sqlalchemy.orm import joinedload
+from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -26,17 +26,19 @@ async def get_relationship(
else RelationshipType.BLOCK
)
relationships = await db.exec(
- select(Relationship)
- .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
- .where(
+ select(Relationship).where(
Relationship.user_id == current_user.id,
Relationship.type == relationship_type,
)
)
- return [await RelationshipResp.from_db(db, rel) for rel in relationships]
+ return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
-@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
+class AddFriendResp(BaseModel):
+ user_relation: RelationshipResp
+
+
+@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
@router.post("/blocks", tags=["relationship"])
async def add_relationship(
request: Request,
@@ -87,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW:
relationship = (
await db.exec(
- select(Relationship)
- .where(
+ select(Relationship).where(
Relationship.user_id == current_user_id,
Relationship.target_id == target,
)
- .options(
- joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
- )
)
).first()
assert relationship, "Relationship should exist after commit"
diff --git a/app/router/room.py b/app/router/room.py
index 6e6f4a1..2677b75 100644
--- a/app/router/room.py
+++ b/app/router/room.py
@@ -1,106 +1,296 @@
from __future__ import annotations
-from app.database.room import RoomIndex
+from datetime import UTC, datetime
+from typing import Literal
+
+from app.database.beatmap import Beatmap, BeatmapResp
+from app.database.beatmapset import BeatmapsetResp
+from app.database.lazer_user import User, UserResp
+from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
+from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
+from app.database.playlists import Playlist, PlaylistResp
+from app.database.room import Room, RoomBase, RoomResp
+from app.database.score import Score
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
+from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
-from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room
+from app.models.multiplayer_hub import (
+ MultiplayerRoom,
+ MultiplayerRoomUser,
+ ServerMultiplayerRoom,
+)
+from app.models.room import RoomStatus
+from app.signalr.hub import MultiplayerHubs
+
+from .api_router import router
-from api_router import router
from fastapi import Depends, HTTPException, Query
-from sqlmodel import select
+from pydantic import BaseModel, Field
+from redis.asyncio import Redis
+from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
-@router.get("/rooms", tags=["rooms"], response_model=list[Room])
+@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp])
async def get_all_rooms(
- mode: str | None = Query(None), # TODO: 对房间根据状态进行筛选
- status: str | None = Query(None),
- category: str | None = Query(
- None
- ), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗)
+ mode: Literal["open", "ended", "participated", "owned", None] = Query(
+ default="open"
+ ), # TODO: 对房间根据状态进行筛选
+ category: str = Query(default="realtime"), # TODO
+ status: RoomStatus | None = Query(None),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
+ redis: Redis = Depends(get_redis),
+ current_user: User = Depends(get_current_user),
):
- all_roomID = (await db.exec(select(RoomIndex))).all()
- redis = get_redis()
- if redis is not None:
- resp: list[Room] = []
- for id in all_roomID:
- dumped_room = redis.get(str(id))
- validated_room = MultiplayerRoom.model_validate_json(str(dumped_room))
- flag: bool = False
- if status is not None:
- if (
- validated_room.State == MultiplayerRoomState.OPEN
- and status == "idle"
- ):
- flag = True
- elif validated_room != MultiplayerRoomState.CLOSED:
- flag = True
- if flag:
- resp.append(
- await Room.from_mpRoom(
- MultiplayerRoom.model_validate_json(str(dumped_room)),
- db,
- fetcher,
- )
- )
- return resp
- else:
- raise HTTPException(status_code=500, detail="Redis Error")
+ rooms = MultiplayerHubs.rooms.values()
+ resp_list: list[RoomResp] = []
+ for room in rooms:
+ # if category == "realtime" and room.category != "normal":
+ # continue
+ # elif category != room.category and category != "":
+ # continue
+ resp_list.append(await RoomResp.from_hub(room))
+ return resp_list
-@router.get("/rooms/{room}", tags=["room"], response_model=Room)
+class APICreatedRoom(RoomResp):
+ error: str = ""
+
+
+class APIUploadedRoom(RoomBase):
+ def to_room(self) -> Room:
+ """
+ 将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
+ """
+ room_dict = self.model_dump()
+ room_dict.pop("playlist", None)
+ # host_id 已在字段中
+ return Room(**room_dict)
+
+ id: int | None
+ host_id: int | None = None
+ playlist: list[Playlist]
+
+
+@router.post("/rooms", tags=["room"], response_model=APICreatedRoom)
+async def create_room(
+ room: APIUploadedRoom,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ # db_room = Room.from_resp(room)
+ await db.refresh(current_user)
+ user_id = current_user.id
+ db_room = room.to_room()
+ db_room.host_id = current_user.id if current_user.id else 1
+ db.add(db_room)
+ await db.commit()
+ await db.refresh(db_room)
+
+ playlist: list[Playlist] = []
+ # 处理 APIUploadedRoom 里的 playlist 字段
+ for item in room.playlist:
+ # 确保 room_id 正确赋值
+ item.id = await Playlist.get_next_id_for_room(db_room.id, db)
+ item.room_id = db_room.id
+ item.owner_id = user_id if user_id else 1
+ db.add(item)
+ await db.commit()
+ await db.refresh(item)
+ playlist.append(item)
+ await db.refresh(db_room)
+ db_room.playlist = playlist
+ server_room = ServerMultiplayerRoom(
+ room=MultiplayerRoom.from_db(db_room),
+ category=db_room.category,
+ start_at=datetime.now(UTC),
+ hub=MultiplayerHubs,
+ )
+ MultiplayerHubs.rooms[db_room.id] = server_room
+ created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room))
+ created_room.error = ""
+ return created_room
+
+
+@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp)
async def get_room(
room: int,
db: AsyncSession = Depends(get_db),
- fetcher: Fetcher = Depends(get_fetcher),
):
- redis = get_redis()
- if redis:
- dumped_room = str(redis.get(str(room)))
- if dumped_room is not None:
- resp = await Room.from_mpRoom(
- MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher
- )
- return resp
- else:
- raise HTTPException(status_code=404, detail="Room Not Found")
- else:
- raise HTTPException(status_code=500, detail="Redis error")
-
-
-class APICreatedRoom(Room):
- error: str | None
-
-
-@router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom)
-async def create_room(
- room: Room,
- db: AsyncSession = Depends(get_db),
- fetcher: Fetcher = Depends(get_fetcher),
-):
- redis = get_redis()
- if redis:
- room_index = RoomIndex()
- db.add(room_index)
- await db.commit()
- await db.refresh(room_index)
- server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher)
- redis.set(str(room_index.id), server_room.model_dump_json())
- room.room_id = room_index.id
- return APICreatedRoom(**room.model_dump(), error=None)
- else:
- raise HTTPException(status_code=500, detail="redis error")
+ server_room = MultiplayerHubs.rooms[room]
+ return await RoomResp.from_hub(server_room)
@router.delete("/rooms/{room}", tags=["room"])
-async def remove_room(room: int, db: AsyncSession = Depends(get_db)):
- redis = get_redis()
- if redis:
- redis.delete(str(room))
- room_index = await db.get(RoomIndex, room)
- if room_index:
- await db.delete(room_index)
+async def delete_room(room: int, db: AsyncSession = Depends(get_db)):
+ db_room = (await db.exec(select(Room).where(Room.id == room))).first()
+ if db_room is None:
+ raise HTTPException(404, "Room not found")
+ else:
+ await db.delete(db_room)
+ return None
+
+
+@router.put("/rooms/{room}/users/{user}", tags=["room"])
+async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)):
+ server_room = MultiplayerHubs.rooms[room]
+ server_room.room.users.append(MultiplayerRoomUser(user_id=user))
+ db_room = (await db.exec(select(Room).where(Room.id == room))).first()
+ if db_room is not None:
+ db_room.participant_count += 1
await db.commit()
+ resp = await RoomResp.from_hub(server_room)
+ await db.refresh(db_room)
+ for item in db_room.playlist:
+ resp.playlist.append(await PlaylistResp.from_db(item, ["beatmap"]))
+ return resp
+ else:
+ raise HTTPException(404, "room not found0")
+
+
+class APILeaderboard(BaseModel):
+ leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
+ user_score: ItemAttemptsResp | None = None
+
+
+@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard)
+async def get_room_leaderboard(
+ room: int,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ server_room = MultiplayerHubs.rooms[room]
+ if not server_room:
+ raise HTTPException(404, "Room not found")
+
+ aggs = await db.exec(
+ select(ItemAttemptsCount)
+ .where(ItemAttemptsCount.room_id == room)
+ .order_by(col(ItemAttemptsCount.total_score).desc())
+ )
+ aggs_resp = []
+ user_agg = None
+ for i, agg in enumerate(aggs):
+ resp = await ItemAttemptsResp.from_db(agg, db)
+ resp.position = i + 1
+ aggs_resp.append(resp)
+ if agg.user_id == current_user.id:
+ user_agg = resp
+ return APILeaderboard(
+ leaderboard=aggs_resp,
+ user_score=user_agg,
+ )
+
+
+class RoomEvents(BaseModel):
+ beatmaps: list[BeatmapResp] = Field(default_factory=list)
+ beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
+ current_playlist_item_id: int = 0
+ events: list[MultiplayerEventResp] = Field(default_factory=list)
+ first_event_id: int = 0
+ last_event_id: int = 0
+ playlist_items: list[PlaylistResp] = Field(default_factory=list)
+ room: RoomResp
+ user: list[UserResp] = Field(default_factory=list)
+
+
+@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"])
+async def get_room_events(
+ room_id: int,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+ limit: int = Query(100, ge=1, le=1000),
+ after: int | None = Query(None, ge=0),
+ before: int | None = Query(None, ge=0),
+):
+ events = (
+ await db.exec(
+ select(MultiplayerEvent)
+ .where(
+ MultiplayerEvent.room_id == room_id,
+ col(MultiplayerEvent.id) > after if after is not None else True,
+ col(MultiplayerEvent.id) < before if before is not None else True,
+ )
+ .order_by(col(MultiplayerEvent.id).desc())
+ .limit(limit)
+ )
+ ).all()
+
+ user_ids = set()
+ playlist_items = {}
+ beatmap_ids = set()
+
+ event_resps = []
+ first_event_id = 0
+ last_event_id = 0
+
+ current_playlist_item_id = 0
+ for event in events:
+ event_resps.append(MultiplayerEventResp.from_db(event))
+
+ if event.user_id:
+ user_ids.add(event.user_id)
+
+ if event.playlist_item_id is not None and (
+ playitem := (
+ await db.exec(
+ select(Playlist).where(
+ Playlist.id == event.playlist_item_id,
+ Playlist.room_id == room_id,
+ )
+ )
+ ).first()
+ ):
+ current_playlist_item_id = playitem.id
+ playlist_items[event.playlist_item_id] = playitem
+ beatmap_ids.add(playitem.beatmap_id)
+ scores = await db.exec(
+ select(Score).where(
+ Score.playlist_item_id == event.playlist_item_id,
+ Score.room_id == room_id,
+ )
+ )
+ for score in scores:
+ user_ids.add(score.user_id)
+ beatmap_ids.add(score.beatmap_id)
+
+ assert event.id is not None
+ first_event_id = min(first_event_id, event.id)
+ last_event_id = max(last_event_id, event.id)
+
+ if room := MultiplayerHubs.rooms.get(room_id):
+ current_playlist_item_id = room.queue.current_item.id
+ room_resp = await RoomResp.from_hub(room)
+ else:
+ room = (await db.exec(select(Room).where(Room.id == room_id))).first()
+ if room is None:
+ raise HTTPException(404, "Room not found")
+ room_resp = await RoomResp.from_db(room)
+
+ users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
+ user_resps = [await UserResp.from_db(user, db) for user in users]
+ beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
+ beatmap_resps = [
+ await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
+ ]
+ beatmapset_resps = {}
+ for beatmap_resp in beatmap_resps:
+ beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
+
+ playlist_items_resps = [
+ await PlaylistResp.from_db(item) for item in playlist_items.values()
+ ]
+
+ return RoomEvents(
+ beatmaps=beatmap_resps,
+ beatmapsets=beatmapset_resps,
+ current_playlist_item_id=current_playlist_item_id,
+ events=event_resps,
+ first_event_id=first_event_id,
+ last_event_id=last_event_id,
+ playlist_items=playlist_items_resps,
+ room=room_resp,
+ user=user_resps,
+ )
diff --git a/app/router/score.py b/app/router/score.py
index cc38dcc..5db171d 100644
--- a/app/router/score.py
+++ b/app/router/score.py
@@ -1,18 +1,41 @@
from __future__ import annotations
+from datetime import UTC, datetime
+import time
+
+from app.calculator import clamp
from app.database import (
- User as DBUser,
+ Beatmap,
+ Playlist,
+ Room,
+ Score,
+ ScoreResp,
+ ScoreToken,
+ ScoreTokenResp,
+ User,
+)
+from app.database.playlist_attempts import ItemAttemptsCount
+from app.database.playlist_best_score import (
+ PlaylistBestScore,
+ get_position,
+ process_playlist_best_score,
+)
+from app.database.score import (
+ MultiplayerScores,
+ ScoreAround,
+ get_leaderboard,
+ process_score,
+ process_user,
)
-from app.database.beatmap import Beatmap
-from app.database.score import Score, ScoreResp, process_score, process_user
-from app.database.score_token import ScoreToken, ScoreTokenResp
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
+from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus
from app.models.score import (
INT_TO_MODE,
GameMode,
+ LeaderboardType,
Rank,
SoloScoreSubmissionInfo,
)
@@ -21,11 +44,75 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
-from redis import Redis
+from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
-from sqlmodel import col, select, true
+from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
+READ_SCORE_TIMEOUT = 10
+
+
+async def submit_score(
+ info: SoloScoreSubmissionInfo,
+ beatmap: int,
+ token: int,
+ current_user: User,
+ db: AsyncSession,
+ redis: Redis,
+ fetcher: Fetcher,
+ item_id: int | None = None,
+ room_id: int | None = None,
+):
+ if not info.passed:
+ info.rank = Rank.F
+ score_token = (
+ await db.exec(
+ select(ScoreToken)
+ .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
+ .where(ScoreToken.id == token)
+ )
+ ).first()
+ if not score_token or score_token.user_id != current_user.id:
+ raise HTTPException(status_code=404, detail="Score token not found")
+ if score_token.score_id:
+ score = (
+ await db.exec(
+ select(Score).where(
+ Score.id == score_token.score_id,
+ Score.user_id == current_user.id,
+ )
+ )
+ ).first()
+ if not score:
+ raise HTTPException(status_code=404, detail="Score not found")
+ else:
+ beatmap_status = (
+ await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap))
+ ).first()
+ if beatmap_status is None:
+ raise HTTPException(status_code=404, detail="Beatmap not found")
+ ranked = beatmap_status in {
+ BeatmapRankStatus.RANKED,
+ BeatmapRankStatus.APPROVED,
+ }
+ score = await process_score(
+ current_user,
+ beatmap,
+ ranked,
+ score_token,
+ info,
+ fetcher,
+ db,
+ redis,
+ )
+ await db.refresh(current_user)
+ score_id = score.id
+ score_token.score_id = score_id
+ await process_user(db, current_user, score, ranked)
+ score = (await db.exec(select(Score).where(Score.id == score_id))).first()
+ assert score is not None
+ return await ScoreResp.from_db(db, score)
+
class BeatmapScores(BaseModel):
scores: list[ScoreResp]
@@ -37,44 +124,26 @@ class BeatmapScores(BaseModel):
)
async def get_beatmap_scores(
beatmap: int,
+ mode: GameMode,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
- mode: GameMode | None = Query(None),
- # mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
- type: str = Query(None),
- current_user: DBUser = Depends(get_current_user),
+ mods: list[str] = Query(default_factory=set, alias="mods[]"),
+ type: LeaderboardType = Query(LeaderboardType.GLOBAL),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
+ limit: int = Query(50, ge=1, le=200),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
- all_scores = (
- await db.exec(
- Score.select_clause_unique(
- Score.beatmap_id == beatmap,
- col(Score.passed).is_(True),
- Score.gamemode == mode if mode is not None else true(),
- )
- )
- ).all()
-
- user_score = (
- await db.exec(
- Score.select_clause_unique(
- Score.beatmap_id == beatmap,
- Score.user_id == current_user.id,
- col(Score.passed).is_(True),
- Score.gamemode == mode if mode is not None else true(),
- )
- )
- ).first()
+ all_scores, user_score = await get_leaderboard(
+ db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
+ )
return BeatmapScores(
- scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores],
- userScore=await ScoreResp.from_db(db, user_score, user_score.user)
- if user_score
- else None,
+ scores=[await ScoreResp.from_db(db, score) for score in all_scores],
+ userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
)
@@ -94,7 +163,7 @@ async def get_user_beatmap_score(
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -103,7 +172,7 @@ async def get_user_beatmap_score(
)
user_score = (
await db.exec(
- Score.select_clause(True)
+ select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
@@ -118,9 +187,10 @@ async def get_user_beatmap_score(
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
)
else:
+ resp = await ScoreResp.from_db(db, user_score)
return BeatmapUserScore(
- position=user_score.position if user_score.position is not None else 0,
- score=await ScoreResp.from_db(db, user_score, user_score.user),
+ position=resp.rank_global or 0,
+ score=resp,
)
@@ -134,7 +204,7 @@ async def get_user_all_beatmap_scores(
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -143,7 +213,7 @@ async def get_user_all_beatmap_scores(
)
all_user_scores = (
await db.exec(
- Score.select_clause()
+ select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
@@ -153,9 +223,7 @@ async def get_user_all_beatmap_scores(
)
).all()
- return [
- await ScoreResp.from_db(db, score, current_user) for score in all_user_scores
- ]
+ return [await ScoreResp.from_db(db, score) for score in all_user_scores]
@router.post(
@@ -166,9 +234,10 @@ async def create_solo_score(
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
+ assert current_user.id
async with db:
score_token = ScoreToken(
user_id=current_user.id,
@@ -190,64 +259,275 @@ async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
- current_user: DBUser = Depends(get_current_user),
+ current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
- if not info.passed:
- info.rank = Rank.F
- async with db:
- score_token = (
- await db.exec(
- select(ScoreToken)
- .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
- .where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
+ return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
+
+
+@router.post(
+ "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
+)
+async def create_playlist_score(
+ room_id: int,
+ playlist_id: int,
+ beatmap_id: int = Form(),
+ beatmap_hash: str = Form(),
+ ruleset_id: int = Form(..., ge=0, le=3),
+ version_hash: str = Form(""),
+ current_user: User = Depends(get_current_user),
+ session: AsyncSession = Depends(get_db),
+):
+ room = await session.get(Room, room_id)
+ if not room:
+ raise HTTPException(status_code=404, detail="Room not found")
+ if room.ended_at and room.ended_at < datetime.now(UTC):
+ raise HTTPException(status_code=400, detail="Room has ended")
+ item = (
+ await session.exec(
+ select(Playlist).where(
+ Playlist.id == playlist_id, Playlist.room_id == room_id
)
- ).first()
- if not score_token or score_token.user_id != current_user.id:
- raise HTTPException(status_code=404, detail="Score token not found")
- if score_token.score_id:
- score = (
- await db.exec(
- select(Score)
- .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
- .where(
- Score.id == score_token.score_id,
- Score.user_id == current_user.id,
+ )
+ ).first()
+ if not item:
+ raise HTTPException(status_code=404, detail="Playlist not found")
+
+ # validate
+ if not item.freestyle:
+ if item.ruleset_id != ruleset_id:
+ raise HTTPException(
+ status_code=400, detail="Ruleset mismatch in playlist item"
+ )
+ if item.beatmap_id != beatmap_id:
+ raise HTTPException(
+ status_code=400, detail="Beatmap ID mismatch in playlist item"
+ )
+ agg = await session.exec(
+ select(ItemAttemptsCount).where(
+ ItemAttemptsCount.room_id == room_id,
+ ItemAttemptsCount.user_id == current_user.id,
+ )
+ )
+ agg = agg.first()
+ if agg and room.max_attempts and agg.attempts >= room.max_attempts:
+ raise HTTPException(
+ status_code=422,
+ detail="You have reached the maximum attempts for this room",
+ )
+ if item.expired:
+ raise HTTPException(status_code=400, detail="Playlist item has expired")
+ if item.played_at:
+ raise HTTPException(
+ status_code=400, detail="Playlist item has already been played"
+ )
+ # 这里应该不用验证mod了吧。。。
+
+ score_token = ScoreToken(
+ user_id=current_user.id,
+ beatmap_id=beatmap_id,
+ ruleset_id=INT_TO_MODE[ruleset_id],
+ playlist_item_id=playlist_id,
+ )
+ session.add(score_token)
+ await session.commit()
+ await session.refresh(score_token)
+ return ScoreTokenResp.from_db(score_token)
+
+
+@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
+async def submit_playlist_score(
+ room_id: int,
+ playlist_id: int,
+ token: int,
+ info: SoloScoreSubmissionInfo,
+ current_user: User = Depends(get_current_user),
+ session: AsyncSession = Depends(get_db),
+ redis: Redis = Depends(get_redis),
+ fetcher: Fetcher = Depends(get_fetcher),
+):
+ item = (
+ await session.exec(
+ select(Playlist).where(
+ Playlist.id == playlist_id, Playlist.room_id == room_id
+ )
+ )
+ ).first()
+ if not item:
+ raise HTTPException(status_code=404, detail="Playlist item not found")
+
+ user_id = current_user.id
+ score_resp = await submit_score(
+ info,
+ item.beatmap_id,
+ token,
+ current_user,
+ session,
+ redis,
+ fetcher,
+ item.id,
+ room_id,
+ )
+ await process_playlist_best_score(
+ room_id,
+ playlist_id,
+ user_id,
+ score_resp.id,
+ score_resp.total_score,
+ session,
+ redis,
+ )
+ await ItemAttemptsCount.get_or_create(room_id, user_id, session)
+ return score_resp
+
+
+class IndexedScoreResp(MultiplayerScores):
+ total: int
+ user_score: ScoreResp | None = None
+
+
+@router.get(
+ "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp
+)
+async def index_playlist_scores(
+ room_id: int,
+ playlist_id: int,
+ limit: int = 50,
+ cursor: int = Query(2000000, alias="cursor[total_score]"),
+ current_user: User = Depends(get_current_user),
+ session: AsyncSession = Depends(get_db),
+):
+ limit = clamp(limit, 1, 50)
+
+ scores = (
+ await session.exec(
+ select(PlaylistBestScore)
+ .where(
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.room_id == room_id,
+ PlaylistBestScore.total_score < cursor,
+ )
+ .order_by(col(PlaylistBestScore.total_score).desc())
+ .limit(limit + 1)
+ )
+ ).all()
+ has_more = len(scores) > limit
+ if has_more:
+ scores = scores[:-1]
+
+ user_score = None
+ score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
+ for score in score_resp:
+ score.position = await get_position(room_id, playlist_id, score.id, session)
+ if score.user_id == current_user.id:
+ user_score = score
+ resp = IndexedScoreResp(
+ scores=score_resp,
+ user_score=user_score,
+ total=len(scores),
+ params={
+ "limit": limit,
+ },
+ )
+ if has_more:
+ resp.cursor = {
+ "total_score": scores[-1].total_score,
+ }
+ return resp
+
+
+@router.get(
+ "/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
+ response_model=ScoreResp,
+)
+async def show_playlist_score(
+ room_id: int,
+ playlist_id: int,
+ score_id: int,
+ current_user: User = Depends(get_current_user),
+ session: AsyncSession = Depends(get_db),
+ redis: Redis = Depends(get_redis),
+):
+ start_time = time.time()
+ score_record = None
+ completed = False
+ while time.time() - start_time < READ_SCORE_TIMEOUT:
+ if score_record is None:
+ score_record = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.score_id == score_id,
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.room_id == room_id,
)
)
).first()
- if not score:
- raise HTTPException(status_code=404, detail="Score not found")
- else:
- beatmap_status = (
- await db.exec(
- select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
+ if completed_players := await redis.get(
+ f"multiplayer:{room_id}:gameplay:players"
+ ):
+ completed = completed_players == "0"
+ if score_record and completed:
+ break
+ if not score_record:
+ raise HTTPException(status_code=404, detail="Score not found")
+ resp = await ScoreResp.from_db(session, score_record.score)
+ resp.position = await get_position(room_id, playlist_id, score_id, session)
+ if completed:
+ scores = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.room_id == room_id,
)
- ).first()
- if beatmap_status is None:
- raise HTTPException(status_code=404, detail="Beatmap not found")
- ranked = beatmap_status in {
- BeatmapRankStatus.RANKED,
- BeatmapRankStatus.APPROVED,
- }
- score = await process_score(
- current_user,
- beatmap,
- ranked,
- score_token,
- info,
- fetcher,
- db,
- redis,
)
- await db.refresh(current_user)
- score_id = score.id
- score_token.score_id = score_id
- await process_user(db, current_user, score, ranked)
- score = (
- await db.exec(Score.select_clause().where(Score.id == score_id))
- ).first()
- assert score is not None
- return await ScoreResp.from_db(db, score, current_user)
+ ).all()
+ higher_scores = []
+ lower_scores = []
+ for score in scores:
+ if score.total_score > resp.total_score:
+ higher_scores.append(await ScoreResp.from_db(session, score.score))
+ elif score.total_score < resp.total_score:
+ lower_scores.append(await ScoreResp.from_db(session, score.score))
+ resp.scores_around = ScoreAround(
+ higher=MultiplayerScores(scores=higher_scores),
+ lower=MultiplayerScores(scores=lower_scores),
+ )
+
+ return resp
+
+
+@router.get(
+ "rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
+ response_model=ScoreResp,
+)
+async def get_user_playlist_score(
+ room_id: int,
+ playlist_id: int,
+ user_id: int,
+ current_user: User = Depends(get_current_user),
+ session: AsyncSession = Depends(get_db),
+):
+ score_record = None
+ start_time = time.time()
+ while time.time() - start_time < READ_SCORE_TIMEOUT:
+ score_record = (
+ await session.exec(
+ select(PlaylistBestScore).where(
+ PlaylistBestScore.user_id == user_id,
+ PlaylistBestScore.playlist_id == playlist_id,
+ PlaylistBestScore.room_id == room_id,
+ )
+ )
+ ).first()
+ if score_record:
+ break
+ if not score_record:
+ raise HTTPException(status_code=404, detail="Score not found")
+
+ resp = await ScoreResp.from_db(session, score_record.score)
+ resp.position = await get_position(
+ room_id, playlist_id, score_record.score_id, session
+ )
+ return resp
diff --git a/app/router/user.py b/app/router/user.py
index 6e169c3..649f1d4 100644
--- a/app/router/user.py
+++ b/app/router/user.py
@@ -1,12 +1,9 @@
from __future__ import annotations
-from typing import Literal
-
-from app.database import User as DBUser
+from app.database import User, UserResp
+from app.database.lazer_user import SEARCH_INCLUDED
from app.dependencies.database import get_db
-from app.models.score import INT_TO_MODE
-from app.models.user import User as ApiUser
-from app.utils import convert_db_user_to_api_user
+from app.models.score import GameMode
from .api_router import router
@@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
-# ---------- Shared Utility ----------
-async def get_user_by_lookup(
- db: AsyncSession, lookup: str, key: str = "id"
-) -> DBUser | None:
- """根据查找方式获取用户"""
- if key == "id":
- try:
- user_id = int(lookup)
- result = await db.exec(select(DBUser).where(DBUser.id == user_id))
- return result.first()
- except ValueError:
- return None
- elif key == "username":
- result = await db.exec(select(DBUser).where(DBUser.name == lookup))
- return result.first()
- else:
- return None
-
-
-# ---------- Batch Users ----------
class BatchUserResponse(BaseModel):
- users: list[ApiUser]
+ users: list[UserResp]
@router.get("/users", response_model=BatchUserResponse)
@@ -51,75 +28,44 @@ async def get_users(
):
if user_ids:
searched_users = (
- await session.exec(
- DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
- )
+ await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
).all()
else:
- searched_users = (
- await session.exec(DBUser.all_select_clause().limit(50))
- ).all()
+ searched_users = (await session.exec(select(User).limit(50))).all()
return BatchUserResponse(
users=[
- await convert_db_user_to_api_user(
- searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
+ await UserResp.from_db(
+ searched_user,
+ session,
+ include=SEARCH_INCLUDED,
)
for searched_user in searched_users
]
)
-# # ---------- Individual User ----------
-# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
-# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
-# async def get_user_with_mode(
-# user_lookup: str,
-# mode: Literal["osu", "taiko", "fruits", "mania"],
-# key: Literal["id", "username"] = Query("id"),
-# current_user: DBUser = Depends(get_current_user),
-# db: AsyncSession = Depends(get_db),
-# ):
-# """获取指定游戏模式的用户信息"""
-# user = await get_user_by_lookup(db, user_lookup, key)
-# if not user:
-# raise HTTPException(status_code=404, detail="User not found")
-
-# return await convert_db_user_to_api_user(user, mode)
-
-
-# @router.get("/users/{user_lookup}", response_model=ApiUser)
-# @router.get("/users/{user_lookup}/", response_model=ApiUser)
-# async def get_user_default(
-# user_lookup: str,
-# key: Literal["id", "username"] = Query("id"),
-# current_user: DBUser = Depends(get_current_user),
-# db: AsyncSession = Depends(get_db),
-# ):
-# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)"""
-# user = await get_user_by_lookup(db, user_lookup, key)
-# if not user:
-# raise HTTPException(status_code=404, detail="User not found")
-
-# return await convert_db_user_to_api_user(user, "osu")
-
-
-@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
-@router.get("/users/{user}/", response_model=ApiUser)
-@router.get("/users/{user}", response_model=ApiUser)
+@router.get("/users/{user}/{ruleset}", response_model=UserResp)
+@router.get("/users/{user}/", response_model=UserResp)
+@router.get("/users/{user}", response_model=UserResp)
async def get_user_info(
user: str,
- ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
+ ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db),
):
searched_user = (
await session.exec(
- DBUser.all_select_clause().where(
- DBUser.id == int(user)
+ select(User).where(
+ User.id == int(user)
if user.isdigit()
- else DBUser.name == user.removeprefix("@")
+ else User.username == user.removeprefix("@")
)
)
).first()
if not searched_user:
raise HTTPException(404, detail="User not found")
- return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
+ return await UserResp.from_db(
+ searched_user,
+ session,
+ include=SEARCH_INCLUDED,
+ ruleset=ruleset,
+ )
diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py
index 276140f..4bab451 100644
--- a/app/signalr/hub/hub.py
+++ b/app/signalr/hub/hub.py
@@ -6,9 +6,9 @@ import time
from typing import Any
from app.config import settings
+from app.exception import InvokeException
from app.log import logger
from app.models.signalr import UserState
-from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
@@ -21,7 +21,6 @@ from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
-from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
@@ -49,7 +48,7 @@ class Client:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
- self.procotol = protocol
+ self.protocol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
@@ -62,14 +61,14 @@ class Client:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
- await self.connection.send_bytes(self.procotol.encode(packet))
+ await self.connection.send_bytes(self.protocol.encode(packet))
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return []
- return self.procotol.decode(d)
+ return self.protocol.decode(d)
async def _ping(self):
while True:
@@ -263,10 +262,9 @@ class Hub[TState: UserState]:
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
- if issubclass(param.annotation, BaseModel):
- call_params.append(param.annotation.model_validate(args.pop(0)))
- else:
- call_params.append(args.pop(0))
+ call_params.append(
+ client.protocol.validate_object(args.pop(0), param.annotation)
+ )
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:
diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py
index 821d831..08ee035 100644
--- a/app/signalr/hub/metadata.py
+++ b/app/signalr/hub/metadata.py
@@ -2,15 +2,15 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
+from datetime import UTC, datetime
from typing import override
-from app.database.relationship import Relationship, RelationshipType
-from app.dependencies.database import engine
+from app.database import Relationship, RelationshipType, User
+from app.dependencies.database import engine, get_redis
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
from .hub import Client, Hub
-from pydantic import TypeAdapter
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -30,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
) -> set[Coroutine]:
if store is not None and not store.pushable:
return set()
- data = store.to_dict() if store else None
+ data = store.for_push if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
@@ -54,6 +54,18 @@ class MetadataHub(Hub[MetadataClientState]):
async def _clean_state(self, state: MetadataClientState) -> None:
if state.pushable:
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
+ redis = get_redis()
+ if await redis.exists(f"metadata:online:{state.connection_id}"):
+ await redis.delete(f"metadata:online:{state.connection_id}")
+ async with AsyncSession(engine) as session:
+ async with session.begin():
+ user = (
+ await session.exec(
+ select(User).where(User.id == int(state.connection_id))
+ )
+ ).one()
+ user.last_visit = datetime.now(UTC)
+ await session.commit()
@override
def create_state(self, client: Client) -> MetadataClientState:
@@ -89,10 +101,14 @@ class MetadataHub(Hub[MetadataClientState]):
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
- friend_state.to_dict(),
+ friend_state.for_push
+ if friend_state.pushable
+ else None,
)
)
await asyncio.gather(*tasks)
+ redis = get_redis()
+ await redis.set(f"metadata:online:{user_id}", "")
async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status)
@@ -107,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
- store.to_dict(),
+ store.for_push,
)
)
await asyncio.gather(*tasks)
- async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
+ async def UpdateActivity(
+ self, client: Client, activity: UserActivity | None
+ ) -> None:
user_id = int(client.connection_id)
- activity = (
- TypeAdapter(UserActivity).validate_python(activity_dict)
- if activity_dict
- else None
- )
store = self.get_or_create_state(client)
- store.user_activity = activity
+ store.activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
- store.to_dict(),
+ store.for_push,
)
)
await asyncio.gather(*tasks)
@@ -139,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
- store.to_dict(),
+ store,
)
for user_id, store in self.state.items()
if store.pushable
diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py
index 72b4a52..3688efa 100644
--- a/app/signalr/hub/multiplayer.py
+++ b/app/signalr/hub/multiplayer.py
@@ -1,6 +1,1138 @@
from __future__ import annotations
-from .hub import Hub
+import asyncio
+from datetime import UTC, datetime, timedelta
+from typing import override
+
+from app.database import Room
+from app.database.beatmap import Beatmap
+from app.database.lazer_user import User
+from app.database.multiplayer_event import MultiplayerEvent
+from app.database.playlists import Playlist
+from app.database.relationship import Relationship, RelationshipType
+from app.dependencies.database import engine, get_redis
+from app.exception import InvokeException
+from app.log import logger
+from app.models.mods import APIMod
+from app.models.multiplayer_hub import (
+ BeatmapAvailability,
+ ForceGameplayStartCountdown,
+ GameplayAbortReason,
+ MatchRequest,
+ MatchServerEvent,
+ MatchStartCountdown,
+ MatchStartedEventDetail,
+ MultiplayerClientState,
+ MultiplayerRoom,
+ MultiplayerRoomSettings,
+ MultiplayerRoomUser,
+ PlaylistItem,
+ ServerMultiplayerRoom,
+ ServerShuttingDownCountdown,
+ StartMatchCountdownRequest,
+ StopCountdownRequest,
+)
+from app.models.room import (
+ DownloadState,
+ MatchType,
+ MultiplayerRoomState,
+ MultiplayerUserState,
+ RoomCategory,
+ RoomStatus,
+)
+from app.models.score import GameMode
+
+from .hub import Client, Hub
+
+from sqlalchemy import update
+from sqlmodel import col, select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+GAMEPLAY_LOAD_TIMEOUT = 30
-class MultiplayerHub(Hub): ...
+class MultiplayerEventLogger:
+ def __init__(self):
+ pass
+
+ async def log_event(self, event: MultiplayerEvent):
+ try:
+ async with AsyncSession(engine) as session:
+ session.add(event)
+ await session.commit()
+ except Exception as e:
+ logger.warning(f"Failed to log multiplayer room event to database: {e}")
+
+ async def room_created(self, room_id: int, user_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ user_id=user_id,
+ event_type="room_created",
+ )
+ await self.log_event(event)
+
+ async def room_disbanded(self, room_id: int, user_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ user_id=user_id,
+ event_type="room_disbanded",
+ )
+ await self.log_event(event)
+
+ async def player_joined(self, room_id: int, user_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ user_id=user_id,
+ event_type="player_joined",
+ )
+ await self.log_event(event)
+
+ async def player_left(self, room_id: int, user_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ user_id=user_id,
+ event_type="player_left",
+ )
+ await self.log_event(event)
+
+ async def player_kicked(self, room_id: int, user_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ user_id=user_id,
+ event_type="player_kicked",
+ )
+ await self.log_event(event)
+
+ async def host_changed(self, room_id: int, user_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ user_id=user_id,
+ event_type="host_changed",
+ )
+ await self.log_event(event)
+
+ async def game_started(
+ self, room_id: int, playlist_item_id: int, details: MatchStartedEventDetail
+ ):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ playlist_item_id=playlist_item_id,
+ event_type="game_started",
+ event_detail=details, # pyright: ignore[reportArgumentType]
+ )
+ await self.log_event(event)
+
+ async def game_aborted(self, room_id: int, playlist_item_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ playlist_item_id=playlist_item_id,
+ event_type="game_aborted",
+ )
+ await self.log_event(event)
+
+ async def game_completed(self, room_id: int, playlist_item_id: int):
+ event = MultiplayerEvent(
+ room_id=room_id,
+ playlist_item_id=playlist_item_id,
+ event_type="game_completed",
+ )
+ await self.log_event(event)
+
+
+class MultiplayerHub(Hub[MultiplayerClientState]):
+ @override
+ def __init__(self):
+ super().__init__()
+ self.rooms: dict[int, ServerMultiplayerRoom] = {}
+ self.event_logger = MultiplayerEventLogger()
+
+ @staticmethod
+ def group_id(room: int) -> str:
+ return f"room:{room}"
+
+ @override
+ def create_state(self, client: Client) -> MultiplayerClientState:
+ return MultiplayerClientState(
+ connection_id=client.connection_id,
+ connection_token=client.connection_token,
+ )
+
+ @override
+ async def _clean_state(self, state: MultiplayerClientState):
+ user_id = int(state.connection_id)
+ if state.room_id != 0 and state.room_id in self.rooms:
+ server_room = self.rooms[state.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == user_id), None)
+ if user is not None:
+ await self.make_user_leave(
+ self.get_client_by_id(str(user_id)), server_room, user
+ )
+
+ async def CreateRoom(self, client: Client, room: MultiplayerRoom):
+ logger.info(f"[MultiplayerHub] {client.user_id} creating room")
+ store = self.get_or_create_state(client)
+ if store.room_id != 0:
+ raise InvokeException("You are already in a room")
+ async with AsyncSession(engine) as session:
+ async with session:
+ db_room = Room(
+ name=room.settings.name,
+ category=RoomCategory.NORMAL,
+ type=room.settings.match_type,
+ queue_mode=room.settings.queue_mode,
+ auto_skip=room.settings.auto_skip,
+ auto_start_duration=int(
+ room.settings.auto_start_duration.total_seconds()
+ ),
+ host_id=client.user_id,
+ status=RoomStatus.IDLE,
+ )
+ session.add(db_room)
+ await session.commit()
+ await session.refresh(db_room)
+ item = room.playlist[0]
+ item.owner_id = client.user_id
+ room.room_id = db_room.id
+ starts_at = db_room.starts_at or datetime.now(UTC)
+ await Playlist.add_to_db(item, db_room.id, session)
+ server_room = ServerMultiplayerRoom(
+ room=room,
+ category=RoomCategory.NORMAL,
+ start_at=starts_at,
+ hub=self,
+ )
+ self.rooms[room.room_id] = server_room
+ await server_room.set_handler()
+ await self.event_logger.room_created(room.room_id, client.user_id)
+ return await self.JoinRoomWithPassword(
+ client, room.room_id, room.settings.password
+ )
+
+ async def JoinRoom(self, client: Client, room_id: int):
+ return self.JoinRoomWithPassword(client, room_id, "")
+
+ async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str):
+ logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}")
+ store = self.get_or_create_state(client)
+ if store.room_id != 0:
+ raise InvokeException("You are already in a room")
+ user = MultiplayerRoomUser(user_id=client.user_id)
+ if room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[room_id]
+ room = server_room.room
+ for u in room.users:
+ if u.user_id == client.user_id:
+ raise InvokeException("You are already in this room")
+ if room.settings.password != password:
+ raise InvokeException("Incorrect password")
+ if room.host is None:
+ # from CreateRoom
+ room.host = user
+ store.room_id = room_id
+ await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user)
+ room.users.append(user)
+ self.add_to_group(client, self.group_id(room_id))
+ await server_room.match_type_handler.handle_join(user)
+ await self.event_logger.player_joined(room_id, user.user_id)
+ return room
+
+ async def ChangeBeatmapAvailability(
+ self, client: Client, beatmap_availability: BeatmapAvailability
+ ):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ availability = user.availability
+ if (
+ availability.state == beatmap_availability.state
+ and availability.download_progress == beatmap_availability.download_progress
+ ):
+ return
+ user.availability = beatmap_availability
+ await self.broadcast_group_call(
+ self.group_id(store.room_id),
+ "UserBeatmapAvailabilityChanged",
+ user.user_id,
+ (beatmap_availability),
+ )
+
+ async def AddPlaylistItem(self, client: Client, item: PlaylistItem):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ await server_room.queue.add_item(
+ item,
+ user,
+ )
+
+ async def EditPlaylistItem(self, client: Client, item: PlaylistItem):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ await server_room.queue.edit_item(
+ item,
+ user,
+ )
+
+ async def RemovePlaylistItem(self, client: Client, item_id: int):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ await server_room.queue.remove_item(
+ item_id,
+ user,
+ )
+
+ async def setting_changed(self, room: ServerMultiplayerRoom, beatmap_changed: bool):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "SettingsChanged",
+ (room.room.settings),
+ )
+
+ async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "PlaylistItemAdded",
+ (item),
+ )
+
+ async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "PlaylistItemRemoved",
+ item_id,
+ )
+
+ async def playlist_changed(
+ self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool
+ ):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "PlaylistItemChanged",
+ (item),
+ )
+
+ async def ChangeUserStyle(
+ self, client: Client, beatmap_id: int | None, ruleset_id: int | None
+ ):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ await self.change_user_style(
+ beatmap_id,
+ ruleset_id,
+ server_room,
+ user,
+ )
+
+ async def validate_styles(self, room: ServerMultiplayerRoom):
+ if not room.queue.current_item.freestyle:
+ for user in room.room.users:
+ await self.change_user_style(
+ None,
+ None,
+ room,
+ user,
+ )
+ async with AsyncSession(engine) as session:
+ beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id)
+ if beatmap is None:
+ raise InvokeException("Beatmap not found")
+ beatmap_ids = (
+ await session.exec(
+ select(Beatmap.id, Beatmap.mode).where(
+ Beatmap.beatmapset_id == beatmap.beatmapset_id,
+ )
+ )
+ ).all()
+ for user in room.room.users:
+ beatmap_id = user.beatmap_id
+ ruleset_id = user.ruleset_id
+ user_beatmap = next(
+ (b for b in beatmap_ids if b[0] == beatmap_id),
+ None,
+ )
+ if beatmap_id is not None and user_beatmap is None:
+ beatmap_id = None
+ beatmap_ruleset = user_beatmap[1] if user_beatmap else beatmap.mode
+ if (
+ ruleset_id is not None
+ and beatmap_ruleset != GameMode.OSU
+ and ruleset_id != beatmap_ruleset
+ ):
+ ruleset_id = None
+ await self.change_user_style(
+ beatmap_id,
+ ruleset_id,
+ room,
+ user,
+ )
+
+ for user in room.room.users:
+ is_valid, valid_mods = room.queue.current_item.validate_user_mods(
+ user, user.mods
+ )
+ if not is_valid:
+ await self.change_user_mods(valid_mods, room, user)
+
+ async def change_user_style(
+ self,
+ beatmap_id: int | None,
+ ruleset_id: int | None,
+ room: ServerMultiplayerRoom,
+ user: MultiplayerRoomUser,
+ ):
+ if user.beatmap_id == beatmap_id and user.ruleset_id == ruleset_id:
+ return
+
+ if beatmap_id is not None or ruleset_id is not None:
+ if not room.queue.current_item.freestyle:
+ raise InvokeException("Current item does not allow free user styles.")
+
+ async with AsyncSession(engine) as session:
+ item_beatmap = await session.get(
+ Beatmap, room.queue.current_item.beatmap_id
+ )
+ if item_beatmap is None:
+ raise InvokeException("Item beatmap not found")
+
+ user_beatmap = (
+ item_beatmap
+ if beatmap_id is None
+ else await session.get(Beatmap, beatmap_id)
+ )
+
+ if user_beatmap is None:
+ raise InvokeException("Invalid beatmap selected.")
+
+ if user_beatmap.beatmapset_id != item_beatmap.beatmapset_id:
+ raise InvokeException(
+ "Selected beatmap is not from the same beatmap set."
+ )
+
+ if (
+ ruleset_id is not None
+ and user_beatmap.mode != GameMode.OSU
+ and ruleset_id != user_beatmap.mode
+ ):
+ raise InvokeException(
+ "Selected ruleset is not supported for the given beatmap."
+ )
+
+ user.beatmap_id = beatmap_id
+ user.ruleset_id = ruleset_id
+
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "UserStyleChanged",
+ user.user_id,
+ beatmap_id,
+ ruleset_id,
+ )
+
+ async def ChangeUserMods(self, client: Client, new_mods: list[APIMod]):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ await self.change_user_mods(new_mods, server_room, user)
+
+ async def change_user_mods(
+ self,
+ new_mods: list[APIMod],
+ room: ServerMultiplayerRoom,
+ user: MultiplayerRoomUser,
+ ):
+ is_valid, valid_mods = room.queue.current_item.validate_user_mods(
+ user, new_mods
+ )
+ if not is_valid:
+ incompatible_mods = [
+ mod["acronym"] for mod in new_mods if mod not in valid_mods
+ ]
+ raise InvokeException(
+ f"Incompatible mods were selected: {','.join(incompatible_mods)}"
+ )
+
+ if user.mods == valid_mods:
+ return
+
+ user.mods = valid_mods
+
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "UserModsChanged",
+ user.user_id,
+ valid_mods,
+ )
+
+ async def validate_user_stare(
+ self,
+ room: ServerMultiplayerRoom,
+ old: MultiplayerUserState,
+ new: MultiplayerUserState,
+ ):
+ match new:
+ case MultiplayerUserState.IDLE:
+ if old.is_playing:
+ raise InvokeException(
+ "Cannot return to idle without aborting gameplay."
+ )
+ case MultiplayerUserState.READY:
+ if old != MultiplayerUserState.IDLE:
+ raise InvokeException(f"Cannot change state from {old} to {new}")
+ if room.queue.current_item.expired:
+ raise InvokeException(
+ "Cannot ready up while all items have been played."
+ )
+ case MultiplayerUserState.WAITING_FOR_LOAD:
+ raise InvokeException("Cannot change state from {old} to {new}")
+ case MultiplayerUserState.LOADED:
+ if old != MultiplayerUserState.WAITING_FOR_LOAD:
+ raise InvokeException(f"Cannot change state from {old} to {new}")
+ case MultiplayerUserState.READY_FOR_GAMEPLAY:
+ if old != MultiplayerUserState.LOADED:
+ raise InvokeException(f"Cannot change state from {old} to {new}")
+ case MultiplayerUserState.PLAYING:
+ raise InvokeException("State is managed by the server.")
+ case MultiplayerUserState.FINISHED_PLAY:
+ if old != MultiplayerUserState.PLAYING:
+ raise InvokeException(f"Cannot change state from {old} to {new}")
+ case MultiplayerUserState.RESULTS:
+ raise InvokeException("Cannot change state from {old} to {new}")
+ case MultiplayerUserState.SPECTATING:
+ if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY):
+ raise InvokeException(f"Cannot change state from {old} to {new}")
+
+ async def ChangeState(self, client: Client, state: MultiplayerUserState):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ if user.state == state:
+ return
+ match state:
+ case MultiplayerUserState.IDLE:
+ if user.state.is_playing:
+ return
+ case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY:
+ if not user.state.is_playing:
+ return
+ await self.validate_user_stare(
+ server_room,
+ user.state,
+ state,
+ )
+ await self.change_user_state(server_room, user, state)
+ if state == MultiplayerUserState.SPECTATING and (
+ room.state == MultiplayerRoomState.PLAYING
+ or room.state == MultiplayerRoomState.WAITING_FOR_LOAD
+ ):
+ await self.call_noblock(client, "LoadRequested")
+ await self.update_room_state(server_room)
+
+ async def change_user_state(
+ self,
+ room: ServerMultiplayerRoom,
+ user: MultiplayerRoomUser,
+ state: MultiplayerUserState,
+ ):
+ user.state = state
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "UserStateChanged",
+ user.user_id,
+ user.state,
+ )
+
+ async def update_room_state(self, room: ServerMultiplayerRoom):
+ match room.room.state:
+ case MultiplayerRoomState.OPEN:
+ if room.room.settings.auto_start_enabled:
+ if (
+ not room.queue.current_item.expired
+ and any(
+ u.state == MultiplayerUserState.READY
+ for u in room.room.users
+ )
+ and not any(
+ isinstance(countdown, MatchStartCountdown)
+ for countdown in room.room.active_countdowns
+ )
+ ):
+ await room.start_countdown(
+ MatchStartCountdown(
+ time_remaining=room.room.settings.auto_start_duration
+ ),
+ self.start_match,
+ )
+ case MultiplayerRoomState.WAITING_FOR_LOAD:
+ played_count = len(
+ [True for user in room.room.users if user.state.is_playing]
+ )
+ ready_count = len(
+ [
+ True
+ for user in room.room.users
+ if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY
+ ]
+ )
+ if played_count == ready_count:
+ await self.start_gameplay(room)
+ case MultiplayerRoomState.PLAYING:
+ if all(
+ u.state != MultiplayerUserState.PLAYING for u in room.room.users
+ ):
+ any_user_finished_playing = False
+ for u in filter(
+ lambda u: u.state == MultiplayerUserState.FINISHED_PLAY,
+ room.room.users,
+ ):
+ any_user_finished_playing = True
+ await self.change_user_state(
+ room, u, MultiplayerUserState.RESULTS
+ )
+ await self.change_room_state(room, MultiplayerRoomState.OPEN)
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "ResultsReady",
+ )
+ if any_user_finished_playing:
+ await self.event_logger.game_completed(
+ room.room.room_id,
+ room.queue.current_item.id,
+ )
+ else:
+ await self.event_logger.game_aborted(
+ room.room.room_id,
+ room.queue.current_item.id,
+ )
+ await room.queue.finish_current_item()
+
+ async def change_room_state(
+ self, room: ServerMultiplayerRoom, state: MultiplayerRoomState
+ ):
+ room.room.state = state
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "RoomStateChanged",
+ state,
+ )
+
+ async def StartMatch(self, client: Client):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+ if room.host is None or room.host.user_id != client.user_id:
+ raise InvokeException("You are not the host of this room")
+
+ # Check host state - host must be ready or spectating
+ if room.host.state not in (
+ MultiplayerUserState.SPECTATING,
+ MultiplayerUserState.READY,
+ ):
+ raise InvokeException("Can't start match when the host is not ready.")
+
+ # Check if any users are ready
+ if all(u.state != MultiplayerUserState.READY for u in room.users):
+ raise InvokeException("Can't start match when no users are ready.")
+
+ await self.start_match(server_room)
+
+ async def start_match(self, room: ServerMultiplayerRoom):
+ if room.room.state != MultiplayerRoomState.OPEN:
+ raise InvokeException("Can't start match when already in a running state.")
+ if room.queue.current_item.expired:
+ raise InvokeException("Current playlist item is expired")
+ ready_users = [
+ u
+ for u in room.room.users
+ if u.availability.state == DownloadState.LOCALLY_AVAILABLE
+ and (
+ u.state == MultiplayerUserState.READY
+ or u.state == MultiplayerUserState.IDLE
+ )
+ ]
+ await asyncio.gather(
+ *[
+ self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD)
+ for u in ready_users
+ ]
+ )
+ await self.change_room_state(
+ room,
+ MultiplayerRoomState.WAITING_FOR_LOAD,
+ )
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "LoadRequested",
+ )
+ await room.start_countdown(
+ ForceGameplayStartCountdown(
+ time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)
+ ),
+ self.start_gameplay,
+ )
+ await self.event_logger.game_started(
+ room.room.room_id,
+ room.queue.current_item.id,
+ details=room.match_type_handler.get_details(),
+ )
+
+ async def start_gameplay(self, room: ServerMultiplayerRoom):
+ if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD:
+ raise InvokeException("Room is not ready for gameplay")
+ if room.queue.current_item.expired:
+ raise InvokeException("Current playlist item is expired")
+ playing = False
+ played_user = 0
+ for user in room.room.users:
+ client = self.get_client_by_id(str(user.user_id))
+ if client is None:
+ continue
+
+ if user.state in (
+ MultiplayerUserState.READY_FOR_GAMEPLAY,
+ MultiplayerUserState.LOADED,
+ ):
+ playing = True
+ played_user += 1
+ await self.change_user_state(room, user, MultiplayerUserState.PLAYING)
+ await self.call_noblock(client, "GameplayStarted")
+ elif user.state == MultiplayerUserState.WAITING_FOR_LOAD:
+ await self.change_user_state(room, user, MultiplayerUserState.IDLE)
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "GameplayAborted",
+ GameplayAbortReason.LOAD_TOOK_TOO_LONG,
+ )
+ await self.change_room_state(
+ room,
+ (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN),
+ )
+ if playing:
+ redis = get_redis()
+ await redis.set(
+ f"multiplayer:{room.room.room_id}:gameplay:players",
+ played_user,
+ ex=3600,
+ )
+
+ async def send_match_event(
+ self, room: ServerMultiplayerRoom, event: MatchServerEvent
+ ):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "MatchEvent",
+ event,
+ )
+
+ async def make_user_leave(
+ self,
+ client: Client,
+ room: ServerMultiplayerRoom,
+ user: MultiplayerRoomUser,
+ kicked: bool = False,
+ ):
+ self.remove_from_group(client, self.group_id(room.room.room_id))
+ room.room.users.remove(user)
+
+ if len(room.room.users) == 0:
+ await self.end_room(room)
+ await self.update_room_state(room)
+ if (
+ len(room.room.users) != 0
+ and room.room.host
+ and room.room.host.user_id == user.user_id
+ ):
+ next_host = room.room.users[0]
+ await self.set_host(room, next_host)
+
+ if kicked:
+ await self.call_noblock(client, "UserKicked", user)
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id), "UserKicked", user
+ )
+ else:
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id), "UserLeft", user
+ )
+
+ target_store = self.state.get(user.user_id)
+ if target_store:
+ target_store.room_id = 0
+
+ async def end_room(self, room: ServerMultiplayerRoom):
+ assert room.room.host
+ async with AsyncSession(engine) as session:
+ await session.execute(
+ update(Room)
+ .where(col(Room.id) == room.room.room_id)
+ .values(
+ name=room.room.settings.name,
+ ended_at=datetime.now(UTC),
+ type=room.room.settings.match_type,
+ queue_mode=room.room.settings.queue_mode,
+ auto_skip=room.room.settings.auto_skip,
+ auto_start_duration=int(
+ room.room.settings.auto_start_duration.total_seconds()
+ ),
+ host_id=room.room.host.user_id,
+ )
+ )
+ await self.event_logger.room_disbanded(
+ room.room.room_id,
+ room.room.host.user_id,
+ )
+ del self.rooms[room.room.room_id]
+
+ async def LeaveRoom(self, client: Client):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ return
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ await self.event_logger.player_left(
+ room.room_id,
+ user.user_id,
+ )
+ await self.make_user_leave(client, server_room, user)
+
+ async def KickUser(self, client: Client, user_id: int):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ if room.host is None or room.host.user_id != client.user_id:
+ raise InvokeException("You are not the host of this room")
+
+ if user_id == client.user_id:
+ raise InvokeException("Can't kick self")
+
+ user = next((u for u in room.users if u.user_id == user_id), None)
+ if user is None:
+ raise InvokeException("User not found in this room")
+
+ await self.event_logger.player_kicked(
+ room.room_id,
+ user.user_id,
+ )
+ target_client = self.get_client_by_id(str(user.user_id))
+ if target_client is None:
+ return
+ await self.make_user_leave(target_client, server_room, user, kicked=True)
+
+ async def set_host(self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser):
+ room.room.host = user
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "HostChanged",
+ user.user_id,
+ )
+
+ async def TransferHost(self, client: Client, user_id: int):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ if room.host is None or room.host.user_id != client.user_id:
+ raise InvokeException("You are not the host of this room")
+
+ new_host = next((u for u in room.users if u.user_id == user_id), None)
+ if new_host is None:
+ raise InvokeException("User not found in this room")
+ await self.event_logger.host_changed(
+ room.room_id,
+ new_host.user_id,
+ )
+ await self.set_host(server_room, new_host)
+
+ async def AbortGameplay(self, client: Client):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ if not user.state.is_playing:
+ raise InvokeException("Cannot abort gameplay while not in a gameplay state")
+
+ await self.change_user_state(
+ server_room,
+ user,
+ MultiplayerUserState.IDLE,
+ )
+ await self.update_room_state(server_room)
+
+ async def AbortMatch(self, client: Client):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ if room.host is None or room.host.user_id != client.user_id:
+ raise InvokeException("You are not the host of this room")
+
+ if (
+ room.state != MultiplayerRoomState.PLAYING
+ and room.state != MultiplayerRoomState.WAITING_FOR_LOAD
+ ):
+ raise InvokeException("Cannot abort a match that hasn't started.")
+
+ await asyncio.gather(
+ *[
+ self.change_user_state(server_room, u, MultiplayerUserState.IDLE)
+ for u in room.users
+ if u.state.is_playing
+ ]
+ )
+ await self.broadcast_group_call(
+ self.group_id(room.room_id),
+ "GameplayAborted",
+ GameplayAbortReason.HOST_ABORTED,
+ )
+ await self.update_room_state(server_room)
+
+ async def change_user_match_state(
+ self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser
+ ):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "MatchUserStateChanged",
+ user.user_id,
+ user.match_state,
+ )
+
+ async def change_room_match_state(self, room: ServerMultiplayerRoom):
+ await self.broadcast_group_call(
+ self.group_id(room.room.room_id),
+ "MatchRoomStateChanged",
+ room.room.match_state,
+ )
+
+ async def ChangeSettings(self, client: Client, settings: MultiplayerRoomSettings):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+
+ if room.host is None or room.host.user_id != client.user_id:
+ raise InvokeException("You are not the host of this room")
+
+ if room.state != MultiplayerRoomState.OPEN:
+ raise InvokeException("Cannot change settings while playing")
+
+ if settings.match_type == MatchType.PLAYLISTS:
+ raise InvokeException("Invalid match type selected")
+
+ previous_settings = room.settings
+ room.settings = settings
+
+ if previous_settings.match_type != settings.match_type:
+ await server_room.set_handler()
+ if previous_settings.queue_mode != settings.queue_mode:
+ await server_room.queue.update_queue_mode()
+
+ await self.setting_changed(server_room, beatmap_changed=False)
+ await self.update_room_state(server_room)
+
+ async def SendMatchRequest(self, client: Client, request: MatchRequest):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ if isinstance(request, StartMatchCountdownRequest):
+ if room.host and room.host.user_id != user.user_id:
+ raise InvokeException("You are not the host of this room")
+ if room.state != MultiplayerRoomState.OPEN:
+ raise InvokeException("Cannot start a countdown during ongoing play")
+ await server_room.start_countdown(
+ MatchStartCountdown(time_remaining=request.duration),
+ self.start_match,
+ )
+ elif isinstance(request, StopCountdownRequest):
+ countdown = next(
+ (c for c in room.active_countdowns if c.id == request.id),
+ None,
+ )
+ if countdown is None:
+ return
+ if (
+ isinstance(countdown, MatchStartCountdown)
+ and room.settings.auto_start_enabled
+ ) or isinstance(
+ countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown)
+ ):
+ raise InvokeException("Cannot stop the requested countdown")
+
+ await server_room.stop_countdown(countdown)
+ else:
+ await server_room.match_type_handler.handle_request(user, request)
+
+ async def InvitePlayer(self, client: Client, user_id: int):
+ store = self.get_or_create_state(client)
+ if store.room_id == 0:
+ raise InvokeException("You are not in a room")
+ if store.room_id not in self.rooms:
+ raise InvokeException("Room does not exist")
+ server_room = self.rooms[store.room_id]
+ room = server_room.room
+ user = next((u for u in room.users if u.user_id == client.user_id), None)
+ if user is None:
+ raise InvokeException("You are not in this room")
+
+ async with AsyncSession(engine) as session:
+ db_user = await session.get(User, user_id)
+ target_relationship = (
+ await session.exec(
+ select(Relationship).where(
+ Relationship.user_id == user_id,
+ Relationship.target_id == client.user_id,
+ )
+ )
+ ).first()
+ inviter_relationship = (
+ await session.exec(
+ select(Relationship).where(
+ Relationship.user_id == client.user_id,
+ Relationship.target_id == user_id,
+ )
+ )
+ ).first()
+ if db_user is None:
+ raise InvokeException("User not found")
+ if db_user.id == client.user_id:
+ raise InvokeException("You cannot invite yourself")
+ if db_user.id in [u.user_id for u in room.users]:
+ raise InvokeException("User already invited")
+ if db_user.is_restricted:
+ raise InvokeException("User is restricted")
+ if (
+ inviter_relationship
+ and inviter_relationship.type == RelationshipType.BLOCK
+ ):
+ raise InvokeException("Cannot perform action due to user being blocked")
+ if (
+ target_relationship
+ and target_relationship.type == RelationshipType.BLOCK
+ ):
+ raise InvokeException("Cannot perform action due to user being blocked")
+ if (
+ db_user.pm_friends_only
+ and target_relationship is not None
+ and target_relationship.type != RelationshipType.FOLLOW
+ ):
+ raise InvokeException(
+ "Cannot perform action "
+ "because user has disabled non-friend communications"
+ )
+
+ target_client = self.get_client_by_id(str(user_id))
+ if target_client is None:
+ raise InvokeException("User is not online")
+ await self.call_noblock(
+ target_client,
+ "Invited",
+ client.user_id,
+ room.room_id,
+ room.settings.password,
+ )
diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py
index 0d0899e..b9a3c99 100644
--- a/app/signalr/hub/spectator.py
+++ b/app/signalr/hub/spectator.py
@@ -7,15 +7,13 @@ import struct
import time
from typing import override
-from app.database import Beatmap
+from app.database import Beatmap, User
from app.database.score import Score
from app.database.score_token import ScoreToken
-from app.database.user import User
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
-from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
-from app.models.signalr import serialize_to_list
+from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
@@ -70,8 +68,8 @@ def save_replay(
md5: str,
username: str,
score: Score,
- statistics: ScoreStatisticsInt,
- maximum_statistics: ScoreStatisticsInt,
+ statistics: ScoreStatistics,
+ maximum_statistics: ScoreStatistics,
frames: list[LegacyReplayFrame],
) -> None:
data = bytearray()
@@ -108,8 +106,8 @@ def save_replay(
last_time = 0
for frame in frames:
frame_strs.append(
- f"{frame.time - last_time}|{frame.x or 0.0}"
- f"|{frame.y or 0.0}|{frame.button_state}"
+ f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
+ f"|{frame.mouse_y or 0.0}|{frame.button_state}"
)
last_time = frame.time
frame_strs.append("-12345|0|0|0")
@@ -166,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
async def on_client_connect(self, client: Client) -> None:
tasks = [
- self.call_noblock(
- client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
- )
+ self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
for user_id, store in self.state.items()
if store.state is not None
]
@@ -197,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
).first()
if not user:
return
- name = user.name
+ name = user.username
store.state = state
store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum
@@ -215,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserBeganPlaying",
user_id,
- serialize_to_list(state),
+ state,
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
@@ -223,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
state = self.get_or_create_state(client)
if not state.score:
return
- state.score.score_info.acc = frame_data.header.acc
+ state.score.score_info.accuracy = frame_data.header.accuracy
state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics
@@ -234,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserSentFrames",
user_id,
- frame_data.model_dump(),
+ frame_data,
)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
@@ -297,19 +293,18 @@ class SpectatorHub(Hub[StoreClientState]):
score_record.id,
)
# save replay
- if store.state.state == SpectatedUserState.Passed:
- score_record.has_replay = True
- await session.commit()
- await session.refresh(score_record)
- save_replay(
- ruleset_id=store.ruleset_id,
- md5=store.checksum,
- username=store.score.score_info.user.name,
- score=score_record,
- statistics=store.score.score_info.statistics,
- maximum_statistics=store.score.score_info.maximum_statistics,
- frames=store.score.replay_frames,
- )
+ score_record.has_replay = True
+ await session.commit()
+ await session.refresh(score_record)
+ save_replay(
+ ruleset_id=store.ruleset_id,
+ md5=store.checksum,
+ username=store.score.score_info.user.name,
+ score=score_record,
+ statistics=store.score.score_info.statistics,
+ maximum_statistics=store.score.score_info.maximum_statistics,
+ frames=store.score.replay_frames,
+ )
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
if state.state == SpectatedUserState.Playing:
@@ -318,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
- serialize_to_list(state) if state else None,
+ state,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
@@ -329,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
client,
"UserBeganPlaying",
target_id,
- serialize_to_list(target_store.state),
+ target_store.state,
)
store = self.get_or_create_state(client)
store.watched_user.add(target_id)
@@ -339,7 +334,7 @@ class SpectatorHub(Hub[StoreClientState]):
async with AsyncSession(engine) as session:
async with session.begin():
username = (
- await session.exec(select(User.name).where(User.id == user_id))
+ await session.exec(select(User.username).where(User.id == user_id))
).first()
if not username:
return
diff --git a/app/signalr/packet.py b/app/signalr/packet.py
index e361ef8..8949f4b 100644
--- a/app/signalr/packet.py
+++ b/app/signalr/packet.py
@@ -1,14 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
-from enum import IntEnum
+import datetime
+from enum import Enum, IntEnum
+import inspect
import json
+from types import NoneType, UnionType
from typing import (
Any,
Protocol as TypingProtocol,
+ Union,
+ get_args,
+ get_origin,
)
+from app.models.signalr import SignalRMeta, SignalRUnionMessage
+from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
+
import msgpack_lazer_api as m
+from pydantic import BaseModel
SEP = b"\x1e"
@@ -73,8 +83,67 @@ class Protocol(TypingProtocol):
@staticmethod
def encode(packet: Packet) -> bytes: ...
+ @classmethod
+ def validate_object(cls, v: Any, typ: type) -> Any: ...
+
class MsgpackProtocol:
+ @classmethod
+ def serialize_msgpack(cls, v: Any) -> Any:
+ typ = v.__class__
+ if issubclass(typ, BaseModel):
+ return cls.serialize_to_list(v)
+ elif issubclass(typ, list):
+ return [cls.serialize_msgpack(item) for item in v]
+ elif issubclass(typ, datetime.datetime):
+ return [v, 0]
+ elif issubclass(typ, datetime.timedelta):
+ return int(v.total_seconds() * 10_000_000)
+ elif isinstance(v, dict):
+ return {
+ cls.serialize_msgpack(k): cls.serialize_msgpack(value)
+ for k, value in v.items()
+ }
+ elif issubclass(typ, Enum):
+ list_ = list(typ)
+ return list_.index(v) if v in list_ else v.value
+ return v
+
+ @classmethod
+ def serialize_to_list(cls, value: BaseModel) -> list[Any]:
+ values = []
+ for field, info in value.__class__.model_fields.items():
+ metadata = next(
+ (m for m in info.metadata if isinstance(m, SignalRMeta)), None
+ )
+ if metadata and metadata.member_ignore:
+ continue
+ values.append(cls.serialize_msgpack(v=getattr(value, field)))
+ if issubclass(value.__class__, SignalRUnionMessage):
+ return [value.__class__.union_type, values]
+ else:
+ return values
+
+ @staticmethod
+ def process_object(v: Any, typ: type[BaseModel]) -> Any:
+ if isinstance(v, list):
+ d = {}
+ i = 0
+ for field, info in typ.model_fields.items():
+ metadata = next(
+ (m for m in info.metadata if isinstance(m, SignalRMeta)), None
+ )
+ if metadata and metadata.member_ignore:
+ continue
+ anno = info.annotation
+ if anno is None:
+ d[camel_to_snake(field)] = v[i]
+ else:
+ d[field] = MsgpackProtocol.validate_object(v[i], anno)
+ i += 1
+ return d
+ return v
+
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
@@ -140,6 +209,53 @@ class MsgpackProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
+ @classmethod
+ def validate_object(cls, v: Any, typ: type) -> Any:
+ if issubclass(typ, BaseModel):
+ return typ.model_validate(obj=cls.process_object(v, typ))
+ elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
+ return v[0]
+ elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
+ return datetime.timedelta(seconds=int(v / 10_000_000))
+ elif get_origin(typ) is list:
+ return [cls.validate_object(item, get_args(typ)[0]) for item in v]
+ elif inspect.isclass(typ) and issubclass(typ, Enum):
+ list_ = list(typ)
+ return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
+ elif get_origin(typ) is dict:
+ return {
+ cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
+ v, get_args(typ)[1]
+ )
+ for k, v in v.items()
+ }
+ elif (origin := get_origin(typ)) is Union or origin is UnionType:
+ args = get_args(typ)
+ if len(args) == 2 and NoneType in args:
+ non_none_args = [arg for arg in args if arg is not NoneType]
+ if len(non_none_args) == 1:
+ if v is None:
+ return None
+ return cls.validate_object(v, non_none_args[0])
+
+ # suppose use `MessagePack-CSharp Union | None`
+ # except `X (Other Type) | None`
+ if NoneType in args and v is None:
+ return None
+ if not all(
+ issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
+ ):
+ raise ValueError(
+ f"Cannot validate {v} to {typ}, "
+ "only SignalRUnionMessage subclasses are supported"
+ )
+ union_type = v[0]
+ for arg in args:
+ assert issubclass(arg, SignalRUnionMessage)
+ if arg.union_type == union_type:
+ return cls.validate_object(v[1], arg)
+ return v
+
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
@@ -151,20 +267,24 @@ class MsgpackProtocol:
]
)
if packet.arguments is not None:
- payload.append(packet.arguments)
+ payload.append(
+ [MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
+ )
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
result_kind = 2
if packet.error:
result_kind = 1
- elif packet.result is None:
+ elif packet.result is not None:
result_kind = 3
payload.extend(
[
packet.invocation_id,
result_kind,
- packet.error or packet.result or None,
+ packet.error
+ or MsgpackProtocol.serialize_msgpack(packet.result)
+ or None,
]
)
elif isinstance(packet, ClosePacket):
@@ -181,6 +301,86 @@ class MsgpackProtocol:
class JSONProtocol:
+ @classmethod
+ def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
+ typ = v.__class__
+ if issubclass(typ, BaseModel):
+ return cls.serialize_model(v, in_union)
+ elif isinstance(v, dict):
+ return {
+ cls.serialize_to_json(k, True): cls.serialize_to_json(value)
+ for k, value in v.items()
+ }
+ elif isinstance(v, list):
+ return [cls.serialize_to_json(item) for item in v]
+ elif isinstance(v, datetime.datetime):
+ return v.isoformat()
+ elif isinstance(v, datetime.timedelta):
+ # d.hh:mm:ss
+ total_seconds = int(v.total_seconds())
+ hours, remainder = divmod(total_seconds, 3600)
+ minutes, seconds = divmod(remainder, 60)
+ return f"{hours:02}:{minutes:02}:{seconds:02}"
+ elif isinstance(v, Enum) and dict_key:
+ return v.value
+ elif isinstance(v, Enum):
+ list_ = list(typ)
+ return list_.index(v)
+ return v
+
+ @classmethod
+ def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
+ d = {}
+ is_union = issubclass(v.__class__, SignalRUnionMessage)
+ for field, info in v.__class__.model_fields.items():
+ metadata = next(
+ (m for m in info.metadata if isinstance(m, SignalRMeta)), None
+ )
+ if metadata and metadata.json_ignore:
+ continue
+ name = (
+ snake_to_camel(
+ field,
+ metadata.use_abbr if metadata else True,
+ )
+ if not is_union
+ else snake_to_pascal(
+ field,
+ metadata.use_abbr if metadata else True,
+ )
+ )
+ d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
+ if is_union and not in_union:
+ return {
+ "$dtype": v.__class__.__name__,
+ "$value": d,
+ }
+ return d
+
+ @staticmethod
+ def process_object(
+ v: Any, typ: type[BaseModel], from_union: bool = False
+ ) -> dict[str, Any]:
+ d = {}
+ for field, info in typ.model_fields.items():
+ metadata = next(
+ (m for m in info.metadata if isinstance(m, SignalRMeta)), None
+ )
+ if metadata and metadata.json_ignore:
+ continue
+ name = (
+ snake_to_camel(field, metadata.use_abbr if metadata else True)
+ if not from_union
+ else snake_to_pascal(field, metadata.use_abbr if metadata else True)
+ )
+ value = v.get(name)
+ anno = typ.model_fields[field].annotation
+ if anno is None:
+ d[field] = value
+ continue
+ d[field] = JSONProtocol.validate_object(value, anno)
+ return d
+
@staticmethod
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
@@ -225,6 +425,63 @@ class JSONProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
+ @classmethod
+ def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
+ if issubclass(typ, BaseModel):
+ return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
+ elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
+ return datetime.datetime.fromisoformat(v)
+ elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
+ # d.hh:mm:ss
+ parts = v.split(":")
+ if len(parts) == 3:
+ return datetime.timedelta(
+ hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
+ )
+ elif len(parts) == 2:
+ return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
+ elif len(parts) == 1:
+ return datetime.timedelta(seconds=int(parts[0]))
+ elif get_origin(typ) is list:
+ return [cls.validate_object(item, get_args(typ)[0]) for item in v]
+ elif inspect.isclass(typ) and issubclass(typ, Enum):
+ list_ = list(typ)
+ return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
+ elif get_origin(typ) is dict:
+ return {
+ cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
+ v, get_args(typ)[1]
+ )
+ for k, v in v.items()
+ }
+ elif (origin := get_origin(typ)) is Union or origin is UnionType:
+ args = get_args(typ)
+ if len(args) == 2 and NoneType in args:
+ non_none_args = [arg for arg in args if arg is not NoneType]
+ if len(non_none_args) == 1:
+ if v is None:
+ return None
+ return cls.validate_object(v, non_none_args[0])
+
+ # suppose use `MessagePack-CSharp Union | None`
+ # except `X (Other Type) | None`
+ if NoneType in args and v is None:
+ return None
+ if not all(
+ issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
+ ):
+ raise ValueError(
+ f"Cannot validate {v} to {typ}, "
+ "only SignalRUnionMessage subclasses are supported"
+ )
+ # https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
+ union_type = v["$dtype"]
+ for arg in args:
+ assert issubclass(arg, SignalRUnionMessage)
+ if arg.__name__ == union_type:
+ return cls.validate_object(v["$value"], arg, True)
+ return v
+
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
@@ -241,7 +498,9 @@ class JSONProtocol:
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
- payload["arguments"] = packet.arguments
+ payload["arguments"] = [
+ JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
+ ]
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
@@ -253,7 +512,7 @@ class JSONProtocol:
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
- payload["result"] = packet.result
+ payload["result"] = JSONProtocol.serialize_to_json(packet.result)
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):
diff --git a/app/utils.py b/app/utils.py
index 9008706..22f06dd 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,465 +1,92 @@
from __future__ import annotations
-from datetime import UTC, datetime
-
-from app.database import (
- LazerUserCounts,
- LazerUserProfile,
- LazerUserStatistics,
- User as DBUser,
-)
-from app.models.user import (
- Country,
- Cover,
- DailyChallengeStats,
- GradeCounts,
- Kudosu,
- Level,
- Page,
- RankHighest,
- RankHistory,
- Statistics,
- User,
- UserAchievement,
-)
-
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000
-async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
- """将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
-
- # 从db_user获取基本字段值
- user_id = getattr(db_user, "id")
- user_name = getattr(db_user, "name")
- user_country = getattr(db_user, "country")
- user_country_code = user_country # 在User模型中,country字段就是country_code
-
- # 获取 Lazer 用户资料
- profile = db_user.lazer_profile
- if not profile:
- # 如果没有 lazer 资料,使用默认值
- profile = LazerUserProfile(
- user_id=user_id,
- )
-
- # 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
- lzrcnt = db_user.lazer_counts
-
- if not lzrcnt:
- # 如果没有 lazer 计数,使用默认值
- lzrcnt = LazerUserCounts(user_id=user_id)
-
- # 获取指定模式的统计信息
- user_stats = None
- if db_user.lazer_statistics:
- for stat in db_user.lazer_statistics:
- if stat.mode == ruleset:
- user_stats = stat
- break
-
- if not user_stats:
- # 如果没有找到指定模式的统计,创建默认统计
- user_stats = LazerUserStatistics(user_id=user_id)
-
- # 获取国家信息
- country_code = db_user.country_code if db_user.country_code is not None else "XX"
-
- country = Country(code=str(country_code), name=get_country_name(str(country_code)))
-
- # 获取 Kudosu 信息
- kudosu = Kudosu(available=0, total=0)
-
- # 获取计数信息
- # counts = LazerUserCounts(user_id=user_id)
-
- # 转换统计信息
- statistics = Statistics(
- count_100=user_stats.count_100,
- count_300=user_stats.count_300,
- count_50=user_stats.count_50,
- count_miss=user_stats.count_miss,
- level=Level(
- current=user_stats.level_current, progress=user_stats.level_progress
- ),
- global_rank=user_stats.global_rank,
- global_rank_exp=user_stats.global_rank_exp,
- pp=float(user_stats.pp) if user_stats.pp else 0.0,
- pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0,
- ranked_score=user_stats.ranked_score,
- hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0,
- play_count=user_stats.play_count,
- play_time=user_stats.play_time,
- total_score=user_stats.total_score,
- total_hits=user_stats.total_hits,
- maximum_combo=user_stats.maximum_combo,
- replays_watched_by_others=user_stats.replays_watched_by_others,
- is_ranked=user_stats.is_ranked,
- grade_counts=GradeCounts(
- ss=user_stats.grade_ss,
- ssh=user_stats.grade_ssh,
- s=user_stats.grade_s,
- sh=user_stats.grade_sh,
- a=user_stats.grade_a,
- ),
- country_rank=user_stats.country_rank,
- rank={"country": user_stats.country_rank} if user_stats.country_rank else None,
- )
-
- # 转换所有模式的统计信息
- statistics_rulesets = {}
- if db_user.lazer_statistics:
- for stat in db_user.lazer_statistics:
- statistics_rulesets[stat.mode] = Statistics(
- count_100=stat.count_100,
- count_300=stat.count_300,
- count_50=stat.count_50,
- count_miss=stat.count_miss,
- level=Level(current=stat.level_current, progress=stat.level_progress),
- global_rank=stat.global_rank,
- global_rank_exp=stat.global_rank_exp,
- pp=float(stat.pp) if stat.pp else 0.0,
- pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0,
- ranked_score=stat.ranked_score,
- hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0,
- play_count=stat.play_count,
- play_time=stat.play_time,
- total_score=stat.total_score,
- total_hits=stat.total_hits,
- maximum_combo=stat.maximum_combo,
- replays_watched_by_others=stat.replays_watched_by_others,
- is_ranked=stat.is_ranked,
- grade_counts=GradeCounts(
- ss=stat.grade_ss,
- ssh=stat.grade_ssh,
- s=stat.grade_s,
- sh=stat.grade_sh,
- a=stat.grade_a,
- ),
- country_rank=stat.country_rank,
- rank={"country": stat.country_rank} if stat.country_rank else None,
- )
-
- # 转换国家信息
- country = Country(code=user_country_code, name=get_country_name(user_country_code))
-
- # 转换封面信息
- cover_url = (
- profile.cover_url
- if profile and profile.cover_url
- else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
- )
- cover = Cover(
- custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None
- )
-
- # 转换 Kudosu 信息
- kudosu = Kudosu(available=0, total=0)
-
- # 转换成就信息
- user_achievements = []
- if db_user.lazer_achievements:
- for achievement in db_user.lazer_achievements:
- user_achievements.append(
- UserAchievement(
- achieved_at=achievement.achieved_at,
- achievement_id=achievement.achievement_id,
- )
- )
-
- # 转换排名历史
- rank_history = None
- rank_history_data = None
- for rh in db_user.rank_history:
- if rh.mode == ruleset:
- rank_history_data = rh.rank_data
- break
-
- if rank_history_data:
- rank_history = RankHistory(mode=ruleset, data=rank_history_data)
-
- # 转换每日挑战统计
- # daily_challenge_stats = None
- # if db_user.daily_challenge_stats:
- # dcs = db_user.daily_challenge_stats
- # daily_challenge_stats = DailyChallengeStats(
- # daily_streak_best=dcs.daily_streak_best,
- # daily_streak_current=dcs.daily_streak_current,
- # last_update=dcs.last_update,
- # last_weekly_streak=dcs.last_weekly_streak,
- # playcount=dcs.playcount,
- # top_10p_placements=dcs.top_10p_placements,
- # top_50p_placements=dcs.top_50p_placements,
- # user_id=dcs.user_id,
- # weekly_streak_best=dcs.weekly_streak_best,
- # weekly_streak_current=dcs.weekly_streak_current,
- # )
-
- # 转换最高排名
- rank_highest = None
- if user_stats.rank_highest:
- rank_highest = RankHighest(
- rank=user_stats.rank_highest,
- updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(),
- )
-
- # 转换团队信息
- team = None
- if db_user.team_membership:
- team_member = db_user.team_membership # 假设用户只属于一个团队
- team = team_member.team
-
- # 创建用户对象
- # 从db_user获取基本字段值
- user_id = getattr(db_user, "id")
- user_name = getattr(db_user, "name")
- user_country = getattr(db_user, "country")
-
- # 获取用户头像URL
- avatar_url = None
-
- # 首先检查 profile 中的 avatar_url
- if profile and hasattr(profile, "avatar_url") and profile.avatar_url:
- avatar_url = str(profile.avatar_url)
-
- # 然后检查是否有关联的头像记录
- if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None:
- if db_user.avatar.r2_game_url:
- # 优先使用游戏用的头像URL
- avatar_url = str(db_user.avatar.r2_game_url)
- elif db_user.avatar.r2_original_url:
- # 其次使用原始头像URL
- avatar_url = str(db_user.avatar.r2_original_url)
-
- # 如果还是没有找到,通过查询获取
- # if db_session and avatar_url is None:
- # try:
- # # 导入UserAvatar模型
-
- # # 尝试查找用户的头像记录
- # statement = select(UserAvatar).where(
- # UserAvatar.user_id == user_id, UserAvatar.is_active == True
- # )
- # avatar_record = db_session.exec(statement).first()
- # if avatar_record is not None:
- # if avatar_record.r2_game_url is not None:
- # # 优先使用游戏用的头像URL
- # avatar_url = str(avatar_record.r2_game_url)
- # elif avatar_record.r2_original_url is not None:
- # # 其次使用原始头像URL
- # avatar_url = str(avatar_record.r2_original_url)
- # except Exception as e:
- # print(f"获取用户头像时出错: {e}")
- # print(f"最终头像URL: {avatar_url}")
- # 如果仍然没有找到头像URL,则使用默认URL
- if avatar_url is None:
- avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
-
- # 处理 profile_order 列表排序
- profile_order = [
- "me",
- "recent_activity",
- "top_ranks",
- "medals",
- "historical",
- "beatmaps",
- "kudosu",
- ]
- if profile and profile.profile_order:
- profile_order = profile.profile_order.split(",")
-
- # 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
- active_tournament_banners = []
- if db_user.active_banners:
- for banner in db_user.active_banners:
- active_tournament_banners.append(
- {
- "tournament_id": banner.tournament_id,
- "image_url": banner.image_url,
- "is_active": banner.is_active,
- }
- )
-
- # 在convert_db_user_to_api_user函数中添加badges处理
- badges = []
- if db_user.lazer_badges:
- for badge in db_user.lazer_badges:
- badges.append(
- {
- "badge_id": badge.badge_id,
- "awarded_at": badge.awarded_at,
- "description": badge.description,
- "image_url": badge.image_url,
- "url": badge.url,
- }
- )
-
- # 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
- monthly_playcounts = []
- if db_user.lazer_monthly_playcounts:
- for playcount in db_user.lazer_monthly_playcounts:
- monthly_playcounts.append(
- {
- "start_date": playcount.start_date.isoformat()
- if playcount.start_date
- else None,
- "play_count": playcount.play_count,
- }
- )
-
- # 在convert_db_user_to_api_user函数中添加previous_usernames处理
- previous_usernames = []
- if db_user.lazer_previous_usernames:
- for username in db_user.lazer_previous_usernames:
- previous_usernames.append(
- {
- "username": username.username,
- "changed_at": username.changed_at.isoformat()
- if username.changed_at
- else None,
- }
- )
-
- # 在convert_db_user_to_api_user函数中添加replays_watched_counts处理
- replays_watched_counts = []
- if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched:
- for replay in db_user.lazer_replays_watched:
- replays_watched_counts.append(
- {
- "start_date": replay.start_date.isoformat()
- if replay.start_date
- else None,
- "count": replay.count,
- }
- )
-
- # 创建用户对象
- user = User(
- id=user_id,
- username=user_name,
- avatar_url=avatar_url,
- country_code=str(country_code),
- default_group=profile.default_group if profile else "default",
- is_active=profile.is_active,
- is_bot=profile.is_bot,
- is_deleted=profile.is_deleted,
- is_online=profile.is_online,
- is_supporter=profile.is_supporter,
- is_restricted=profile.is_restricted,
- last_visit=db_user.last_visit,
- pm_friends_only=profile.pm_friends_only,
- profile_colour=profile.profile_colour,
- cover_url=profile.cover_url
- if profile and profile.cover_url
- else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
- discord=profile.discord if profile else None,
- has_supported=profile.has_supported if profile else False,
- interests=profile.interests if profile else None,
- join_date=profile.join_date if profile.join_date else datetime.now(UTC),
- location=profile.location if profile else None,
- max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
- max_friends=profile.max_friends if profile and profile.max_friends else 500,
- post_count=profile.post_count if profile and profile.post_count else 0,
- profile_hue=profile.profile_hue if profile and profile.profile_hue else None,
- profile_order=profile_order, # 使用排序后的 profile_order
- title=profile.title if profile else None,
- title_url=profile.title_url if profile else None,
- twitter=profile.twitter if profile else None,
- website=profile.website if profile else None,
- session_verified=True,
- support_level=profile.support_level if profile else 0,
- country=country,
- cover=cover,
- kudosu=kudosu,
- statistics=statistics,
- statistics_rulesets=statistics_rulesets,
- beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0,
- comments_count=lzrcnt.comments_count if lzrcnt else 0,
- favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0,
- follower_count=lzrcnt.follower_count if lzrcnt else 0,
- graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0,
- guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0,
- loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0,
- mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0,
- nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0,
- pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0,
- ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0,
- ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count
- if lzrcnt
- else 0,
- unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0,
- scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0,
- scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0,
- scores_pinned_count=lzrcnt.scores_pinned_count,
- scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0,
- account_history=[], # TODO: 获取用户历史账户信息
- # active_tournament_banner=len(active_tournament_banners),
- active_tournament_banners=active_tournament_banners,
- badges=badges,
- current_season_stats=None,
- daily_challenge_user_stats=DailyChallengeStats(
- user_id=user_id,
- daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
- if db_user.daily_challenge_stats
- else 0,
- daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
- if db_user.daily_challenge_stats
- else 0,
- last_update=db_user.daily_challenge_stats.last_update
- if db_user.daily_challenge_stats
- else None,
- last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
- if db_user.daily_challenge_stats
- else None,
- playcount=db_user.daily_challenge_stats.playcount
- if db_user.daily_challenge_stats
- else 0,
- top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
- if db_user.daily_challenge_stats
- else 0,
- top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
- if db_user.daily_challenge_stats
- else 0,
- weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
- if db_user.daily_challenge_stats
- else 0,
- weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
- if db_user.daily_challenge_stats
- else 0,
- ),
- groups=[],
- monthly_playcounts=monthly_playcounts,
- page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
- if profile.page_html or profile.page_raw
- else Page(),
- previous_usernames=previous_usernames,
- rank_highest=rank_highest,
- rank_history=rank_history,
- rankHistory=rank_history,
- replays_watched_counts=replays_watched_counts,
- team=team,
- user_achievements=user_achievements,
- )
-
- return user
+def camel_to_snake(name: str) -> str:
+ """Convert a camelCase string to snake_case."""
+ result = []
+ last_chr = ""
+ for char in name:
+ if char.isupper():
+ if not last_chr.isupper() and result:
+ result.append("_")
+ result.append(char.lower())
+ else:
+ result.append(char)
+ last_chr = char
+ return "".join(result)
-def get_country_name(country_code: str) -> str:
- """根据国家代码获取国家名称"""
- country_names = {
- "CN": "China",
- "JP": "Japan",
- "US": "United States",
- "GB": "United Kingdom",
- "DE": "Germany",
- "FR": "France",
- "KR": "South Korea",
- "CA": "Canada",
- "AU": "Australia",
- "BR": "Brazil",
- # 可以添加更多国家
+def snake_to_camel(name: str, use_abbr: bool = True) -> str:
+ """Convert a snake_case string to camelCase."""
+ if not name:
+ return name
+
+ parts = name.split("_")
+ if not parts:
+ return name
+
+ # 常见缩写词列表
+ abbreviations = {
+ "id",
+ "url",
+ "api",
+ "http",
+ "https",
+ "xml",
+ "json",
+ "css",
+ "html",
+ "sql",
+ "db",
}
- return country_names.get(country_code, "Unknown")
+
+ result = []
+ for part in parts:
+ if part.lower() in abbreviations and use_abbr:
+ result.append(part.upper())
+ else:
+ if result:
+ result.append(part.capitalize())
+ else:
+ result.append(part.lower())
+
+ return "".join(result)
+
+
+def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
+ """Convert a snake_case string to PascalCase."""
+ if not name:
+ return name
+
+ parts = name.split("_")
+ if not parts:
+ return name
+
+ # 常见缩写词列表
+ abbreviations = {
+ "id",
+ "url",
+ "api",
+ "http",
+ "https",
+ "xml",
+ "json",
+ "css",
+ "html",
+ "sql",
+ "db",
+ }
+
+ result = []
+ for part in parts:
+ if part.lower() in abbreviations and use_abbr:
+ result.append(part.upper())
+ else:
+ result.append(part.capitalize())
+
+ return "".join(result)
diff --git a/main.py b/main.py
index 526d593..b12f543 100644
--- a/main.py
+++ b/main.py
@@ -4,25 +4,27 @@ from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
-from app.database import Team # noqa: F401
-from app.dependencies.database import create_tables, engine
+from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.fetcher import get_fetcher
-from app.models.user import User
-from app.router import api_router, auth_router, fetcher_router, signalr_router
+from app.router import (
+ api_router,
+ auth_router,
+ fetcher_router,
+ signalr_router,
+)
from fastapi import FastAPI
-User.model_rebuild()
-
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
await create_tables()
- get_fetcher() # 初始化 fetcher
+ await get_fetcher() # 初始化 fetcher
# on shutdown
yield
await engine.dispose()
+ await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
@@ -44,104 +46,6 @@ async def health_check():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
-# @app.get("/api/v2/friends")
-# async def get_friends():
-# return JSONResponse(
-# content=[
-# {
-# "id": 123456,
-# "username": "BestFriend",
-# "is_online": True,
-# "is_supporter": False,
-# "country": {"code": "US", "name": "United States"},
-# }
-# ]
-# )
-
-
-# @app.get("/api/v2/notifications")
-# async def get_notifications():
-# return JSONResponse(content={"notifications": [], "unread_count": 0})
-
-
-# @app.post("/api/v2/chat/ack")
-# async def chat_ack():
-# return JSONResponse(content={"status": "ok"})
-
-
-# @app.get("/api/v2/users/{user_id}/{mode}")
-# async def get_user_mode(user_id: int, mode: str):
-# return JSONResponse(
-# content={
-# "id": user_id,
-# "username": "测试测试测",
-# "statistics": {
-# "level": {"current": 97, "progress": 96},
-# "pp": 114514,
-# "global_rank": 666,
-# "country_rank": 1,
-# "hit_accuracy": 100,
-# },
-# "country": {"code": "JP", "name": "Japan"},
-# }
-# )
-
-
-# @app.get("/api/v2/me")
-# async def get_me():
-# return JSONResponse(
-# content={
-# "id": 15651670,
-# "username": "Googujiang",
-# "is_online": True,
-# "country": {"code": "JP", "name": "Japan"},
-# "statistics": {
-# "level": {"current": 97, "progress": 96},
-# "pp": 2826.26,
-# "global_rank": 298026,
-# "country_rank": 11220,
-# "hit_accuracy": 95.7168,
-# },
-# }
-# )
-
-
-# @app.post("/signalr/metadata/negotiate")
-# async def metadata_negotiate(negotiateVersion: int = 1):
-# return JSONResponse(
-# content={
-# "connectionId": "abc123",
-# "availableTransports": [
-# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
-# ],
-# }
-# )
-
-
-# @app.post("/signalr/spectator/negotiate")
-# async def spectator_negotiate(negotiateVersion: int = 1):
-# return JSONResponse(
-# content={
-# "connectionId": "spec456",
-# "availableTransports": [
-# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
-# ],
-# }
-# )
-
-
-# @app.post("/signalr/multiplayer/negotiate")
-# async def multiplayer_negotiate(negotiateVersion: int = 1):
-# return JSONResponse(
-# content={
-# "connectionId": "multi789",
-# "availableTransports": [
-# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
-# ],
-# }
-# )
-
-
if __name__ == "__main__":
from app.log import logger # noqa: F401
diff --git a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py
similarity index 62%
rename from migrations/versions/78be13c71791_score_remove_best_id_in_database.py
rename to migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py
index d0cab2b..84bae15 100644
--- a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py
+++ b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py
@@ -1,8 +1,8 @@
-"""score: remove best_id in database
+"""beatmapset: support favourite count
-Revision ID: 78be13c71791
-Revises: dc4d25c428c7
-Create Date: 2025-07-29 07:57:33.764517
+Revision ID: 1178d0758ebf
+Revises:
+Create Date: 2025-08-01 04:05:09.882800
"""
@@ -15,8 +15,8 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision: str = "78be13c71791"
-down_revision: str | Sequence[str] | None = "dc4d25c428c7"
+revision: str = "1178d0758ebf"
+down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_column("scores", "best_id")
+ op.drop_column("beatmapsets", "favourite_count")
# ### end Alembic commands ###
@@ -32,7 +32,9 @@ def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
- "scores",
- sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True),
+ "beatmapsets",
+ sa.Column(
+ "favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False
+ ),
)
# ### end Alembic commands ###
diff --git a/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py
new file mode 100644
index 0000000..e383621
--- /dev/null
+++ b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py
@@ -0,0 +1,54 @@
+"""relationship: fix unique relationship
+
+Revision ID: 58a11441d302
+Revises: 1178d0758ebf
+Create Date: 2025-08-01 04:23:02.498166
+
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+# revision identifiers, used by Alembic.
+revision: str = "58a11441d302"
+down_revision: str | Sequence[str] | None = "1178d0758ebf"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ """Upgrade schema."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "relationship",
+ sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
+ )
+ op.drop_constraint("PRIMARY", "relationship", type_="primary")
+ op.create_primary_key("pk_relationship", "relationship", ["id"])
+ op.alter_column(
+ "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True
+ )
+ op.alter_column(
+ "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Downgrade schema."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint("pk_relationship", "relationship", type_="primary")
+ op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"])
+ op.alter_column(
+ "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False
+ )
+ op.alter_column(
+ "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False
+ )
+ op.drop_column("relationship", "id")
+ # ### end Alembic commands ###
diff --git a/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py
new file mode 100644
index 0000000..74f2e56
--- /dev/null
+++ b/migrations/versions/d0c1b2cefe91_playlist_index_playlist_id.py
@@ -0,0 +1,89 @@
+"""playlist: index playlist id
+
+Revision ID: d0c1b2cefe91
+Revises: 58a11441d302
+Create Date: 2025-08-06 06:02:10.512616
+
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "d0c1b2cefe91"
+down_revision: str | Sequence[str] | None = "58a11441d302"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ """Upgrade schema."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_index(
+ op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False
+ )
+ op.create_table(
+ "playlist_best_scores",
+ sa.Column("user_id", sa.BigInteger(), nullable=True),
+ sa.Column("score_id", sa.BigInteger(), nullable=False),
+ sa.Column("room_id", sa.Integer(), nullable=False),
+ sa.Column("playlist_id", sa.Integer(), nullable=False),
+ sa.Column("total_score", sa.BigInteger(), nullable=True),
+ sa.ForeignKeyConstraint(
+ ["playlist_id"],
+ ["room_playlists.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["room_id"],
+ ["rooms.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["score_id"],
+ ["scores.id"],
+ ),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["lazer_users.id"],
+ ),
+ sa.PrimaryKeyConstraint("score_id"),
+ )
+ op.create_index(
+ op.f("ix_playlist_best_scores_playlist_id"),
+ "playlist_best_scores",
+ ["playlist_id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_playlist_best_scores_room_id"),
+ "playlist_best_scores",
+ ["room_id"],
+ unique=False,
+ )
+ op.create_index(
+ op.f("ix_playlist_best_scores_user_id"),
+ "playlist_best_scores",
+ ["user_id"],
+ unique=False,
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ """Downgrade schema."""
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_index(
+ op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores"
+ )
+ op.drop_index(
+ op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores"
+ )
+ op.drop_index(
+ op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores"
+ )
+ op.drop_table("playlist_best_scores")
+ op.drop_index(op.f("ix_room_playlists_id"), table_name="room_playlists")
+ # ### end Alembic commands ###
diff --git a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py b/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py
deleted file mode 100644
index d90ec3d..0000000
--- a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
-
-Revision ID: dc4d25c428c7
-Revises:
-Create Date: 2025-07-29 01:43:40.221070
-
-"""
-
-from __future__ import annotations
-
-from collections.abc import Sequence
-
-from alembic import op
-import sqlalchemy as sa
-
-# revision identifiers, used by Alembic.
-revision: str = "dc4d25c428c7"
-down_revision: str | Sequence[str] | None = None
-branch_labels: str | Sequence[str] | None = None
-depends_on: str | Sequence[str] | None = None
-
-
-def upgrade() -> None:
- """Upgrade schema."""
- # ### commands auto generated by Alembic - please adjust! ###
- op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
- op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
- # ### end Alembic commands ###
-
-
-def downgrade() -> None:
- """Downgrade schema."""
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_column("scores", "nsmall_tick_hit")
- op.drop_column("scores", "nlarge_tick_hit")
- # ### end Alembic commands ###
diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi
index 88b79c5..433c53b 100644
--- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi
+++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi
@@ -1,11 +1,4 @@
from typing import Any
-class APIMod:
- def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ...
- @property
- def acronym(self) -> str: ...
- @property
- def settings(self) -> str: ...
-
def encode(obj: Any) -> bytes: ...
def decode(data: bytes) -> Any: ...
diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs
index 15156ca..1e36c42 100644
--- a/packages/msgpack_lazer_api/src/decode.rs
+++ b/packages/msgpack_lazer_api/src/decode.rs
@@ -1,8 +1,6 @@
-use crate::APIMod;
use chrono::{TimeZone, Utc};
use pyo3::types::PyDict;
use pyo3::{prelude::*, IntoPyObjectExt};
-use std::collections::HashMap;
use std::io::Read;
pub fn read_object(
@@ -13,6 +11,8 @@ pub fn read_object(
match rmp::decode::read_marker(cursor) {
Ok(marker) => match marker {
rmp::Marker::Null => Ok(py.None()),
+ rmp::Marker::True => Ok(true.into_py_any(py)?),
+ rmp::Marker::False => Ok(false.into_py_any(py)?),
rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()),
rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()),
rmp::Marker::U8 => {
@@ -86,8 +86,6 @@ pub fn read_object(
cursor.read_exact(&mut data).map_err(to_py_err)?;
Ok(data.into_pyobject(py)?.into_any().unbind())
}
- rmp::Marker::True => Ok(true.into_py_any(py)?),
- rmp::Marker::False => Ok(false.into_py_any(py)?),
rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32),
rmp::Marker::Str8 => {
let mut buf = [0u8; 1];
@@ -206,13 +204,12 @@ fn read_array(
let obj1 = read_object(py, cursor, false)?;
if obj1.extract::(py).map_or(false, |k| k.len() == 2) {
let obj2 = read_object(py, cursor, true)?;
- return Ok(APIMod {
- acronym: obj1.extract::(py)?,
- settings: obj2.extract::>(py)?,
- }
- .into_pyobject(py)?
- .into_any()
- .unbind());
+
+ let api_mod_dict = PyDict::new(py);
+ api_mod_dict.set_item("acronym", obj1)?;
+ api_mod_dict.set_item("settings", obj2)?;
+
+ return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind());
} else {
items.push(obj1);
i += 1;
diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs
index 88a732b..3ff4864 100644
--- a/packages/msgpack_lazer_api/src/encode.rs
+++ b/packages/msgpack_lazer_api/src/encode.rs
@@ -1,8 +1,7 @@
-use crate::APIMod;
-use chrono::{DateTime, Utc};
-use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods};
+use chrono::{DateTime, Utc};
+use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods};
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
-use pyo3::{Bound, PyAny, PyRef, Python};
+use pyo3::{Bound, PyAny};
use std::io::Write;
fn write_list(buf: &mut Vec, obj: &Bound<'_, PyList>) {
@@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec, obj: &Bound<'_, PyDict>) {
}
}
-fn write_nil(buf: &mut Vec){
+fn write_nil(buf: &mut Vec) {
rmp::encode::write_nil(buf).unwrap();
}
-// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
-fn write_api_mod(buf: &mut Vec, api_mod: PyRef) {
- rmp::encode::write_array_len(buf, 2).unwrap();
- rmp::encode::write_str(buf, &api_mod.acronym).unwrap();
- rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap();
- for (k, v) in api_mod.settings.iter() {
- rmp::encode::write_str(buf, k).unwrap();
- Python::with_gil(|py| write_object(buf, &v.bind(py)));
+fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool {
+ if let Ok(Some(acronym)) = dict.get_item("acronym") {
+ if let Ok(acronym_str) = acronym.extract::() {
+ return acronym_str.len() == 2;
+ }
}
+ false
+}
+
+// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
+fn write_api_mod(buf: &mut Vec, api_mod: &Bound<'_, PyDict>) -> PyResult<()> {
+ let acronym = api_mod
+ .get_item("acronym")?
+ .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?;
+ let acronym_str = acronym.extract::()?;
+
+ let settings = api_mod
+ .get_item("settings")?
+ .unwrap_or_else(|| PyDict::new(acronym.py()).into_any());
+ let settings_dict = settings.downcast::()?;
+
+ rmp::encode::write_array_len(buf, 2).unwrap();
+ rmp::encode::write_str(buf, &acronym_str).unwrap();
+ rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap();
+
+ for (k, v) in settings_dict.iter() {
+ let key_str = k.extract::()?;
+ rmp::encode::write_str(buf, &key_str).unwrap();
+ write_object(buf, &v);
+ }
+
+ Ok(())
}
fn write_datetime(buf: &mut Vec, obj: &Bound<'_, PyDateTime>) {
@@ -110,22 +132,24 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) {
write_list(buf, list);
} else if let Ok(string) = obj.downcast::() {
write_string(buf, string);
- } else if let Ok(integer) = obj.downcast::() {
- write_integer(buf, integer);
- } else if let Ok(float) = obj.downcast::() {
- write_float(buf, float);
} else if let Ok(boolean) = obj.downcast::() {
write_bool(buf, boolean);
+ } else if let Ok(float) = obj.downcast::() {
+ write_float(buf, float);
+ } else if let Ok(integer) = obj.downcast::() {
+ write_integer(buf, integer);
} else if let Ok(bytes) = obj.downcast::() {
write_bin(buf, bytes);
} else if let Ok(dict) = obj.downcast::() {
- write_hashmap(buf, dict);
+ if is_api_mod(dict) {
+ write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict));
+ } else {
+ write_hashmap(buf, dict);
+ }
} else if let Ok(_none) = obj.downcast::() {
write_nil(buf);
} else if let Ok(datetime) = obj.downcast::() {
write_datetime(buf, datetime);
- } else if let Ok(api_mod) = obj.extract::>() {
- write_api_mod(buf, api_mod);
} else {
panic!("Unsupported type");
}
diff --git a/packages/msgpack_lazer_api/src/lib.rs b/packages/msgpack_lazer_api/src/lib.rs
index fda540c..220e645 100644
--- a/packages/msgpack_lazer_api/src/lib.rs
+++ b/packages/msgpack_lazer_api/src/lib.rs
@@ -2,30 +2,6 @@ mod decode;
mod encode;
use pyo3::prelude::*;
-use std::collections::HashMap;
-
-#[pyclass]
-struct APIMod {
- #[pyo3(get, set)]
- acronym: String,
- #[pyo3(get, set)]
- settings: HashMap,
-}
-
-#[pymethods]
-impl APIMod {
- #[new]
- fn new(acronym: String, settings: HashMap) -> Self {
- APIMod { acronym, settings }
- }
-
- fn __repr__(&self) -> String {
- format!(
- "APIMod(acronym='{}', settings={:?})",
- self.acronym, self.settings
- )
- }
-}
#[pyfunction]
#[pyo3(name = "encode")]
@@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult {
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_py, m)?)?;
m.add_function(wrap_pyfunction!(decode_py, m)?)?;
- m.add_class::()?;
Ok(())
}
diff --git a/static/README.md b/static/README.md
index 16ece63..77b54fe 100644
--- a/static/README.md
+++ b/static/README.md
@@ -2,4 +2,4 @@
- `mods.json`: 包含了游戏中的所有可用mod的详细信息。
- Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json
- - Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380`
+ - Version: 2025/7/30 `ff49b66b27a2850aea4b6b3ba563cfe936cb6082`
diff --git a/static/mods.json b/static/mods.json
index defb57f..0a8449b 100644
--- a/static/mods.json
+++ b/static/mods.json
@@ -2438,7 +2438,8 @@
"Settings": [],
"IncompatibleMods": [
"CN",
- "RX"
+ "RX",
+ "MF"
],
"RequiresConfiguration": false,
"UserPlayable": false,
@@ -2460,7 +2461,8 @@
"AC",
"AT",
"CN",
- "RX"
+ "RX",
+ "MF"
],
"RequiresConfiguration": false,
"UserPlayable": false,
@@ -2477,7 +2479,8 @@
"Settings": [],
"IncompatibleMods": [
"AT",
- "CN"
+ "CN",
+ "MF"
],
"RequiresConfiguration": false,
"UserPlayable": true,
@@ -2638,6 +2641,24 @@
"ValidForMultiplayerAsFreeMod": true,
"AlwaysValidForSubmission": false
},
+ {
+ "Acronym": "MF",
+ "Name": "Moving Fast",
+ "Description": "Dashing by default, slow down!",
+ "Type": "Fun",
+ "Settings": [],
+ "IncompatibleMods": [
+ "AT",
+ "CN",
+ "RX"
+ ],
+ "RequiresConfiguration": false,
+ "UserPlayable": true,
+ "ValidForMultiplayer": true,
+ "ValidForFreestyleAsRequiredMod": false,
+ "ValidForMultiplayerAsFreeMod": true,
+ "AlwaysValidForSubmission": false
+ },
{
"Acronym": "SV2",
"Name": "Score V2",