Merge branch 'feat/multiplayer-api' of https://github.com/GooGuTeam/osu_lazer_api into feat/multiplayer-api

This commit is contained in:
jimmy-sketch
2025-08-08 10:35:55 +00:00
69 changed files with 5464 additions and 2851 deletions

1
.gitignore vendored
View File

@@ -37,6 +37,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
test-cert/
htmlcov/
.tox/
.nox/

View File

@@ -8,7 +8,7 @@ import string
from app.config import settings
from app.database import (
OAuthToken,
User as DBUser,
User,
)
from app.log import logger
@@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str:
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> DBUser | None:
) -> User | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -82,7 +82,7 @@ async def authenticate_user_legacy(
pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name)
statement = select(User).where(User.username == name)
user = (await db.exec(statement)).first()
if not user:
return None
@@ -113,7 +113,7 @@ async def authenticate_user_legacy(
async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> DBUser | None:
) -> User | None:
"""验证用户身份"""
return await authenticate_user_legacy(db, username, password)

View File

@@ -1,3 +1,4 @@
from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthToken
from .beatmap import (
Beatmap as Beatmap,
@@ -8,65 +9,64 @@ from .beatmapset import (
BeatmapsetResp as BeatmapsetResp,
)
from .best_score import BestScore
from .legacy import LegacyOAuthToken, LegacyUserStatistics
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
User,
UserResp,
)
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
from .playlist_best_score import PlaylistBestScore
from .playlists import Playlist, PlaylistResp
from .pp_best_score import PPBestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
from .room import Room, RoomResp
from .score import (
MultiplayerScores,
Score,
ScoreAround,
ScoreBase,
ScoreResp,
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
from .statistics import (
UserStatistics,
UserStatisticsResp,
)
from .team import Team, TeamMember
from .user import (
DailyChallengeStats,
LazerUserAchievement,
LazerUserBadge,
LazerUserBanners,
LazerUserCountry,
LazerUserCounts,
LazerUserKudosu,
LazerUserMonthlyPlaycounts,
LazerUserPreviousUsername,
LazerUserProfile,
LazerUserProfileSections,
LazerUserReplaysWatched,
LazerUserStatistics,
RankHistory,
User,
UserAchievement,
UserAvatar,
from .user_account_history import (
UserAccountHistory,
UserAccountHistoryResp,
UserAccountHistoryType,
)
BeatmapsetResp.model_rebuild()
BeatmapResp.model_rebuild()
__all__ = [
"Beatmap",
"BeatmapResp",
"Beatmapset",
"BeatmapsetResp",
"BestScore",
"DailyChallengeStats",
"LazerUserAchievement",
"LazerUserBadge",
"LazerUserBanners",
"LazerUserCountry",
"LazerUserCounts",
"LazerUserKudosu",
"LazerUserMonthlyPlaycounts",
"LazerUserPreviousUsername",
"LazerUserProfile",
"LazerUserProfileSections",
"LazerUserReplaysWatched",
"LazerUserStatistics",
"LegacyOAuthToken",
"LegacyUserStatistics",
"DailyChallengeStatsResp",
"FavouriteBeatmapset",
"ItemAttemptsCount",
"ItemAttemptsResp",
"MultiplayerEvent",
"MultiplayerEventResp",
"MultiplayerScores",
"OAuthToken",
"RankHistory",
"PPBestScore",
"Playlist",
"PlaylistBestScore",
"PlaylistResp",
"Relationship",
"RelationshipResp",
"RelationshipType",
"Room",
"RoomResp",
"Score",
"ScoreAround",
"ScoreBase",
"ScoreResp",
"ScoreStatistics",
@@ -75,6 +75,17 @@ __all__ = [
"Team",
"TeamMember",
"User",
"UserAccountHistory",
"UserAccountHistoryResp",
"UserAccountHistoryType",
"UserAchievement",
"UserAvatar",
"UserAchievement",
"UserAchievementResp",
"UserResp",
"UserStatistics",
"UserStatisticsResp",
]
for i in __all__:
if i.endswith("Resp"):
globals()[i].model_rebuild() # type: ignore[call-arg]

View File

@@ -0,0 +1,40 @@
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .lazer_user import User
class UserAchievementBase(SQLModel, UTCBaseModel):
achievement_id: int = Field(primary_key=True)
achieved_at: datetime = Field(
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
class UserAchievement(UserAchievementBase, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
)
user: "User" = Relationship(back_populates="achievement")
class UserAchievementResp(UserAchievementBase):
@classmethod
def from_db(cls, db_model: UserAchievement) -> "UserAchievementResp":
return cls.model_validate(db_model)

View File

@@ -1,19 +1,21 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
from .lazer_user import User
class OAuthToken(SQLModel, table=True):
class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)

View File

@@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .lazer_user import User
class BeatmapOwner(SQLModel):
id: int
@@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property
def can_ranked(self) -> bool:
@@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap)
await session.commit()
beatmap = (
await session.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(Beatmap.id == resp.id)
)
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap)
.where(
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
)
).first()
if not beatmap:
@@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase):
url: str = ""
@classmethod
def from_db(
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
from_set: bool = False,
session: AsyncSession | None = None,
user: "User | None" = None,
) -> "BeatmapResp":
beatmap_ = beatmap.model_dump()
if query_mode is not None and beatmap.mode != query_mode:
@@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set:
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=session, user=user
)
return cls.model_validate(beatmap_)

View File

@@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from sqlmodel import Field, Relationship, SQLModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
class BeatmapCovers(SQLModel):
@@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel):
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
creator: str
favourite_count: int
nsfw: bool = Field(default=False)
play_count: int
preview_url: str
@@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel):
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
ratings: list[int] = Field(default=None, sa_column=Column(JSON))
# TODO: recent_favourites: Optional[list[User]] = None
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None) # feature artist?
# TODO: has_favourited
# BeatmapsetExtended
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
@@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel):
tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True):
hype_required: int = Field(default=0)
availability_info: str | None = Field(default=None)
download_disabled: bool = Field(default=False)
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp(
@@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase):
genre: BeatmapTranslationText | None = None
language: BeatmapTranslationText | None = None
nominations: BeatmapNominations | None = None
has_favourited: bool = False
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
@classmethod
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
async def from_db(
cls,
beatmapset: Beatmapset,
include: list[str] = [],
session: AsyncSession | None = None,
user: User | None = None,
) -> "BeatmapsetResp":
from .beatmap import BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
beatmaps = [
BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in beatmapset.beatmaps
]
update = {
"beatmaps": [
await BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
**beatmapset.model_dump(),
}
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id
)
)
).first()
update["has_favourited"] = existing_favourite is not None
if session and "recent_favourites" in include:
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
update["recent_favourites"] = [
await UserResp.from_db(
await favourite.awaitable_attrs.user,
session=session,
include=BASE_INCLUDES,
)
for favourite in recent_favourites
]
if session:
update["favourite_count"] = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).one()
return cls.model_validate(
{
"beatmaps": beatmaps,
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
**beatmapset.model_dump(),
}
update,
)

View File

@@ -1,14 +1,14 @@
from typing import TYPE_CHECKING
from app.models.score import GameMode
from app.models.score import GameMode, Rank
from .user import User
from .lazer_user import User
from sqlmodel import (
JSON,
BigInteger,
Column,
Field,
Float,
ForeignKey,
Relationship,
SQLModel,
@@ -20,22 +20,27 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(
sa_column=Column(Float, default=0),
)
acc: float = Field(
sa_column=Column(Float, default=0),
total_score: int = Field(default=0, sa_column=Column(BigInteger))
mods: list[str] = Field(
default_factory=list,
sa_column=Column(JSON),
)
rank: Rank
user: User = Relationship()
score: "Score" = Relationship()
score: "Score" = Relationship(
sa_relationship_kwargs={
"foreign_keys": "[BestScore.score_id]",
"lazy": "joined",
}
)
beatmap: "Beatmap" = Relationship()

View File

@@ -0,0 +1,58 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .lazer_user import User
class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0)
weekly_streak_best: int = Field(default=0)
weekly_streak_current: int = Field(default=0)
class DailyChallengeStats(DailyChallengeStatsBase, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
user_id: int | None = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
unique=True,
index=True,
primary_key=True,
),
)
user: "User" = Relationship(back_populates="daily_challenge_stats")
class DailyChallengeStatsResp(DailyChallengeStatsBase):
user_id: int
@classmethod
def from_db(
cls,
obj: DailyChallengeStats,
) -> "DailyChallengeStatsResp":
return cls.model_validate(obj)

View File

@@ -0,0 +1,53 @@
import datetime
from app.database.beatmapset import Beatmapset
from app.database.lazer_user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
exclude=True,
)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
beatmapset_id: int = Field(
default=None,
sa_column=Column(
ForeignKey("beatmapsets.id"),
index=True,
),
)
date: datetime.datetime = Field(
default=datetime.datetime.now(datetime.UTC),
sa_column=Column(
DateTime,
),
)
user: User = Relationship(back_populates="favourite_beatmapsets")
beatmapset: Beatmapset = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
},
back_populates="favourites",
)

333
app/database/lazer_user.py Normal file
View File

@@ -0,0 +1,333 @@
from datetime import UTC, datetime
from typing import TYPE_CHECKING, NotRequired, TypedDict
from app.models.model import UTCBaseModel
from app.models.score import GameMode
from app.models.user import Country, Page, RankHistory
from .achievement import UserAchievement, UserAchievementResp
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp
from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
JSON,
BigInteger,
Column,
DateTime,
Field,
Relationship,
SQLModel,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .favourite_beatmapset import FavouriteBeatmapset
from .relationship import RelationshipResp
class Kudosu(TypedDict):
available: int
total: int
class RankHighest(TypedDict):
rank: int
updated_at: datetime
class UserProfileCover(TypedDict):
url: str
custom_url: NotRequired[str]
id: NotRequired[str]
Badge = TypedDict(
"Badge",
{
"awarded_at": datetime,
"description": str,
"image@2x_url": str,
"image_url": str,
"url": str,
},
)
class UserBase(UTCBaseModel, SQLModel):
avatar_url: str = ""
country_code: str = Field(default="CN", max_length=2, index=True)
# ? default_group: str|None
is_active: bool = True
is_bot: bool = False
is_supporter: bool = False
last_visit: datetime | None = Field(
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
pm_friends_only: bool = False
profile_colour: str | None = None
username: str = Field(max_length=32, unique=True, index=True)
page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON))
# TODO: replays_watched_counts
support_level: int = 0
badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
# optional
is_restricted: bool = False
# blocks
cover: UserProfileCover = Field(
default=UserProfileCover(
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
),
sa_column=Column(JSON),
)
beatmap_playcounts_count: int = 0
# kudosu
# UserExtended
playmode: GameMode = GameMode.OSU
discord: str | None = None
has_supported: bool = False
interests: str | None = None
join_date: datetime = Field(default=datetime.now(UTC))
location: str | None = None
max_blocks: int = 50
max_friends: int = 500
occupation: str | None = None
playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON))
# TODO: post_count
profile_hue: int | None = None
profile_order: list[str] = Field(
default_factory=lambda: [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
],
sa_column=Column(JSON),
)
title: str | None = None
title_url: str | None = None
twitter: str | None = None
website: str | None = None
# undocumented
comments_count: int = 0
post_count: int = 0
is_admin: bool = False
is_gmt: bool = False
is_qat: bool = False
is_bng: bool = False
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)
account_history: list[UserAccountHistory] = Relationship()
statistics: list[UserStatistics] = Relationship()
achievement: list[UserAchievement] = Relationship(back_populates="user")
team_membership: TeamMember | None = Relationship(back_populates="user")
daily_challenge_stats: DailyChallengeStats | None = Relationship(
back_populates="user"
)
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True)
pw_bcrypt: str = Field(max_length=60, exclude=True)
silence_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
donor_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
class UserResp(UserBase):
id: int | None = None
is_online: bool = False
groups: list = [] # TODO
country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
favourite_beatmapset_count: int = 0 # TODO
graveyard_beatmapset_count: int = 0 # TODO
guest_beatmapset_count: int = 0 # TODO
loved_beatmapset_count: int = 0 # TODO
mapping_follower_count: int = 0 # TODO
nominated_beatmapset_count: int = 0 # TODO
pending_beatmapset_count: int = 0 # TODO
ranked_beatmapset_count: int = 0 # TODO
follow_user_mapping: list[int] = Field(default_factory=list)
follower_count: int = 0
friends: list["RelationshipResp"] | None = None
scores_best_count: int = 0
scores_first_count: int = 0
scores_recent_count: int = 0
scores_pinned_count: int = 0
account_history: list[UserAccountHistoryResp] = []
active_tournament_banners: list[dict] = [] # TODO
kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO
monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list)
unread_pm_count: int = 0 # TODO
rank_history: RankHistory | None = None # TODO
rank_highest: RankHighest | None = None # TODO
statistics: UserStatisticsResp | None = None
statistics_rulesets: dict[str, UserStatisticsResp] | None = None
user_achievements: list[UserAchievementResp] = Field(default_factory=list)
cover_url: str = "" # deprecated
team: Team | None = None
session_verified: bool = True
daily_challenge_user_stats: DailyChallengeStatsResp | None = None
# TODO: monthly_playcounts, unread_pm_count rank_history, user_preferences
@classmethod
async def from_db(
cls,
obj: User,
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
) -> "UserResp":
from app.dependencies.database import get_redis
from .best_score import BestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
u = cls.model_validate(obj.model_dump())
u.id = obj.id
u.follower_count = (
await session.exec(
select(func.count())
.select_from(Relationship)
.where(
Relationship.target_id == obj.id,
Relationship.type == RelationshipType.FOLLOW,
)
)
).one()
u.scores_best_count = (
await session.exec(
select(func.count())
.select_from(BestScore)
.where(
BestScore.user_id == obj.id,
)
.limit(200)
)
).one()
redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = (
obj.cover.get(
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
if obj.cover
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
if "friends" in include:
u.friends = [
await RelationshipResp.from_db(session, r)
for r in (
await session.exec(
select(Relationship).where(
Relationship.user_id == obj.id,
Relationship.type == RelationshipType.FOLLOW,
)
)
).all()
]
if "team" in include:
if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team
if "account_history" in include:
u.account_history = [
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
if "daily_challenge_user_stats":
if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if "statistics" in include:
current_stattistics = None
for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode):
current_stattistics = i
break
u.statistics = (
UserStatisticsResp.from_db(current_stattistics)
if current_stattistics
else None
)
if "statistics_rulesets" in include:
u.statistics_rulesets = {
i.mode.value: UserStatisticsResp.from_db(i)
for i in await obj.awaitable_attrs.statistics
}
if "monthly_playcounts" in include:
u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
return u
ALL_INCLUDED = [
"friends",
"team",
"account_history",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
"achievements",
"monthly_playcounts",
]
SEARCH_INCLUDED = [
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
"achievements",
"monthly_playcounts",
]
BASE_INCLUDES = [
"team",
"daily_challenge_user_stats",
"statistics",
]

View File

@@ -1,94 +0,0 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import JSON, Column, DateTime
from sqlalchemy.orm import Mapped
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
# ============================================
# 旧的兼容性表模型(保留以便向后兼容)
# ============================================
class LegacyUserStatistics(SQLModel, table=True):
__tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10) # osu, taiko, fruits, mania
# 基本统计
count_100: int = Field(default=0)
count_300: int = Field(default=0)
count_50: int = Field(default=0)
count_miss: int = Field(default=0)
# 等级信息
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
# 排名信息
global_rank: int | None = Field(default=None)
global_rank_exp: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
# PP 和分数
pp: float = Field(default=0.0)
pp_exp: float = Field(default=0.0)
ranked_score: int = Field(default=0)
hit_accuracy: float = Field(default=0.0)
total_score: int = Field(default=0)
total_hits: int = Field(default=0)
maximum_combo: int = Field(default=0)
# 游戏统计
play_count: int = Field(default=0)
play_time: int = Field(default=0)
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=False)
# 成绩等级计数
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
# 最高排名记录
rank_highest: int | None = Field(default=None)
rank_highest_updated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: Mapped["User"] = Relationship(back_populates="statistics")
class LegacyOAuthToken(SQLModel, table=True):
__tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
access_token: str = Field(max_length=255, index=True)
refresh_token: str = Field(max_length=255, index=True)
expires_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON))
replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON))
# 用户关系
user: "User" = Relationship()

View File

@@ -0,0 +1,43 @@
from datetime import date
from typing import TYPE_CHECKING
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .lazer_user import User
class MonthlyPlaycounts(SQLModel, table=True):
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
year: int = Field(index=True)
month: int = Field(index=True)
playcount: int = Field(default=0)
user: "User" = Relationship(back_populates="monthly_playcounts")
class MonthlyPlaycountsResp(SQLModel):
start_date: date
count: int
@classmethod
def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp":
return cls(
start_date=date(db_model.year, db_model.month, 1),
count=db_model.playcount,
)

View File

@@ -0,0 +1,56 @@
from datetime import UTC, datetime
from typing import Any
from app.models.model import UTCBaseModel
from sqlmodel import (
JSON,
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
SQLModel,
)
class MultiplayerEventBase(SQLModel, UTCBaseModel):
playlist_item_id: int | None = None
user_id: int | None = Field(
default=None,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
)
created_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=datetime.now(UTC),
)
event_type: str = Field(index=True)
class MultiplayerEvent(MultiplayerEventBase, table=True):
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)
room_id: int = Field(foreign_key="rooms.id", index=True)
updated_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=datetime.now(UTC),
)
event_detail: dict[str, Any] | None = Field(
sa_column=Column(JSON),
default_factory=dict,
)
class MultiplayerEventResp(MultiplayerEventBase):
id: int
@classmethod
def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp":
return cls.model_validate(event)

View File

@@ -0,0 +1,114 @@
from .lazer_user import User, UserResp
from .playlist_best_score import PlaylistBestScore
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True)
attempts: int = Field(default=0)
completed: int = Field(default=0)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
accuracy: float = 0.0
pp: float = 0
total_score: int = 0
class ItemAttemptsCount(ItemAttemptsCountBase, table=True):
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user: User = Relationship()
async def get_position(self, session: AsyncSession) -> int:
rownum = (
func.row_number()
.over(
partition_by=col(ItemAttemptsCountBase.room_id),
order_by=col(ItemAttemptsCountBase.total_score).desc(),
)
.label("rn")
)
subq = select(ItemAttemptsCountBase, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
result = await session.exec(stmt)
return result.one()
async def update(self, session: AsyncSession):
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == self.room_id,
PlaylistBestScore.user_id == self.user_id,
)
)
).all()
self.attempts = sum(score.attempts for score in playlist_scores)
self.total_score = sum(score.total_score for score in playlist_scores)
self.pp = sum(score.score.pp for score in playlist_scores)
self.completed = len(playlist_scores)
self.accuracy = (
sum(score.score.accuracy * score.attempts for score in playlist_scores)
/ self.completed
if self.completed > 0
else 0.0
)
await session.commit()
await session.refresh(self)
@classmethod
async def get_or_create(
cls,
room_id: int,
user_id: int,
session: AsyncSession,
) -> "ItemAttemptsCount":
item_attempts = await session.exec(
select(cls).where(
cls.room_id == room_id,
cls.user_id == user_id,
)
)
item_attempts = item_attempts.first()
if item_attempts is None:
item_attempts = cls(room_id=room_id, user_id=user_id)
session.add(item_attempts)
await session.commit()
await session.refresh(item_attempts)
await item_attempts.update(session)
return item_attempts
class ItemAttemptsResp(ItemAttemptsCountBase):
user: UserResp | None = None
position: int | None = None
@classmethod
async def from_db(
cls,
item_attempts: ItemAttemptsCount,
session: AsyncSession,
include: list[str] = [],
) -> "ItemAttemptsResp":
resp = cls.model_validate(item_attempts)
resp.user = await UserResp.from_db(
item_attempts.user,
session=session,
include=["statistics", "team", "daily_challenge_user_stats"],
)
if "position" in include:
resp.position = await item_attempts.get_position(session)
return resp

View File

@@ -0,0 +1,109 @@
from typing import TYPE_CHECKING
from .lazer_user import User
from redis.asyncio import Redis
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .score import Score
class PlaylistBestScore(SQLModel, table=True):
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
room_id: int = Field(foreign_key="rooms.id", index=True)
playlist_id: int = Field(foreign_key="room_playlists.id", index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger))
attempts: int = Field(default=0) # playlist
user: User = Relationship()
score: "Score" = Relationship(
sa_relationship_kwargs={
"foreign_keys": "[PlaylistBestScore.score_id]",
"lazy": "joined",
}
)
async def process_playlist_best_score(
room_id: int,
playlist_id: int,
user_id: int,
score_id: int,
total_score: int,
session: AsyncSession,
redis: Redis,
):
previous = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.user_id == user_id,
)
)
).first()
if previous is None:
score = PlaylistBestScore(
user_id=user_id,
score_id=score_id,
room_id=room_id,
playlist_id=playlist_id,
total_score=total_score,
)
session.add(score)
else:
previous.score_id = score_id
previous.total_score = total_score
previous.attempts += 1
await session.commit()
await redis.decr(f"multiplayer:{room_id}:gameplay:players")
async def get_position(
room_id: int,
playlist_id: int,
score_id: int,
session: AsyncSession,
) -> int:
rownum = (
func.row_number()
.over(
partition_by=(
col(PlaylistBestScore.playlist_id),
col(PlaylistBestScore.room_id),
),
order_by=col(PlaylistBestScore.total_score).desc(),
)
.label("row_number")
)
subq = (
select(PlaylistBestScore, rownum)
.where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
.subquery()
)
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0

143
app/database/playlists.py Normal file
View File

@@ -0,0 +1,143 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.multiplayer_hub import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
from sqlmodel import (
JSON,
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .room import Room
class PlaylistBase(SQLModel, UTCBaseModel):
id: int = Field(index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
ruleset_id: int = Field(ge=0, le=3)
expired: bool = Field(default=False)
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
)
allowed_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
)
required_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
)
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
freestyle: bool = Field(default=False)
class Playlist(PlaylistBase, table=True):
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
beatmap: Beatmap = Relationship(
sa_relationship_kwargs={
"lazy": "joined",
}
)
room: "Room" = Relationship()
@classmethod
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
cls.room_id == room_id
)
result = await session.exec(stmt)
return result.one()
@classmethod
async def from_hub(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
owner_id=playlist.owner_id,
ruleset_id=playlist.ruleset_id,
beatmap_id=playlist.beatmap_id,
required_mods=playlist.required_mods,
allowed_mods=playlist.allowed_mods,
expired=playlist.expired,
playlist_order=playlist.playlist_order,
played_at=playlist.played_at,
freestyle=playlist.freestyle,
room_id=room_id,
)
@classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
)
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
db_playlist.owner_id = playlist.owner_id
db_playlist.ruleset_id = playlist.ruleset_id
db_playlist.beatmap_id = playlist.beatmap_id
db_playlist.required_mods = playlist.required_mods
db_playlist.allowed_mods = playlist.allowed_mods
db_playlist.expired = playlist.expired
db_playlist.playlist_order = playlist.playlist_order
db_playlist.played_at = playlist.played_at
db_playlist.freestyle = playlist.freestyle
await session.commit()
@classmethod
async def add_to_db(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
):
db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist)
await session.commit()
await session.refresh(db_playlist)
playlist.id = db_playlist.id
@classmethod
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == item_id, cls.room_id == room_id)
)
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
await session.delete(db_playlist)
await session.commit()
class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(
cls, playlist: Playlist, include: list[str] = []
) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True)
resp = cls.model_validate(data)
return resp

View File

@@ -0,0 +1,41 @@
from typing import TYPE_CHECKING
from app.models.score import GameMode
from .lazer_user import User
from sqlmodel import (
BigInteger,
Column,
Field,
Float,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .beatmap import Beatmap
from .score import Score
class PPBestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(
sa_column=Column(Float, default=0),
)
acc: float = Field(
sa_column=Column(Float, default=0),
)
user: User = Relationship()
score: "Score" = Relationship()
beatmap: "Beatmap" = Relationship()

View File

@@ -1,8 +1,6 @@
from enum import Enum
from app.models.user import User as APIUser
from .user import User as DBUser
from .lazer_user import User, UserResp
from pydantic import BaseModel
from sqlmodel import (
@@ -24,12 +22,16 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
exclude=True,
)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
ForeignKey("lazer_users.id"),
index=True,
),
)
@@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True):
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
ForeignKey("lazer_users.id"),
index=True,
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: DBUser = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
target: User = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)
class RelationshipResp(BaseModel):
target_id: int
target: APIUser
target: UserResp
mutual: bool = False
type: RelationshipType
@@ -58,8 +62,6 @@ class RelationshipResp(BaseModel):
async def from_db(
cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp":
from app.utils import convert_db_user_to_api_user
target_relationship = (
await session.exec(
select(Relationship).where(
@@ -75,7 +77,16 @@ class RelationshipResp(BaseModel):
)
return cls(
target_id=relationship.target_id,
target=await convert_db_user_to_api_user(relationship.target),
target=await UserResp.from_db(
relationship.target,
session,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
),
mutual=mutual,
type=relationship.type,
)

View File

@@ -1,6 +1,125 @@
from sqlmodel import Field, SQLModel
from datetime import UTC, datetime
from app.models.multiplayer_hub import ServerMultiplayerRoom
from app.models.room import (
MatchType,
QueueMode,
RoomCategory,
RoomDifficultyRange,
RoomPlaylistItemStats,
RoomStatus,
)
from .lazer_user import User, UserResp
from .playlists import Playlist, PlaylistResp
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
class RoomIndex(SQLModel, table=True):
__tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType]
id: int = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue]
class RoomBase(SQLModel):
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=datetime.now(UTC),
)
ended_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=None,
)
participant_count: int = Field(default=0)
max_attempts: int | None = Field(default=None) # playlists
type: MatchType
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
status: RoomStatus
# TODO: channel_id
# recent_participants: list[User]
class Room(RoomBase, table=True):
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
host: User = Relationship()
playlist: list[Playlist] = Relationship(
sa_relationship_kwargs={
"lazy": "joined",
"cascade": "all, delete-orphan",
"overlaps": "room",
}
)
class RoomResp(RoomBase):
id: int
password: str | None = None
host: UserResp | None = None
playlist: list[PlaylistResp] = []
playlist_item_stats: RoomPlaylistItemStats | None = None
difficulty_range: RoomDifficultyRange | None = None
current_playlist_item: PlaylistResp | None = None
@classmethod
async def from_db(cls, room: Room) -> "RoomResp":
resp = cls.model_validate(room.model_dump())
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange(
min=0,
max=0,
)
rulesets = set()
for playlist in room.playlist:
stats.count_total += 1
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
difficulty_range.min = min(
difficulty_range.min, playlist.beatmap.difficulty_rating
)
difficulty_range.max = max(
difficulty_range.max, playlist.beatmap.difficulty_rating
)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats
resp.difficulty_range = difficulty_range
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
return resp
@classmethod
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
room = server_room.room
resp = cls(
id=room.room_id,
name=room.settings.name,
type=room.settings.match_type,
queue_mode=room.settings.queue_mode,
auto_skip=room.settings.auto_skip,
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
status=server_room.status,
category=server_room.category,
# duration = room.settings.duration,
starts_at=server_room.start_at,
participant_count=len(room.users),
)
return resp

View File

@@ -1,8 +1,9 @@
import asyncio
from collections.abc import Sequence
from datetime import UTC, datetime
from datetime import UTC, date, datetime
import json
import math
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from app.calculator import (
calculate_pp,
@@ -12,9 +13,8 @@ from app.calculator import (
calculate_weighted_pp,
clamp,
)
from app.database.score_token import ScoreToken
from app.database.user import LazerUserStatistics, User
from app.models.beatmap import BeatmapRankStatus
from app.database.team import TeamMember
from app.models.model import RespWithCursor, UTCBaseModel
from app.models.mods import APIMod, mods_can_get_pp
from app.models.score import (
INT_TO_MODE,
@@ -26,15 +26,24 @@ from app.models.score import (
ScoreStatistics,
SoloScoreSubmissionInfo,
)
from app.models.user import User as APIUser
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmapset import BeatmapsetResp
from .best_score import BestScore
from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts
from .pp_best_score import PPBestScore
from .relationship import (
Relationship as DBRelationship,
RelationshipType,
)
from .score_token import ScoreToken
from redis import Redis
from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import (
JSON,
BigInteger,
@@ -43,9 +52,10 @@ from sqlmodel import (
Relationship,
SQLModel,
col,
false,
func,
select,
text,
true,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectOfScalar
@@ -54,7 +64,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher
class ScoreBase(SQLModel):
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
accuracy: float
map_md5: str = Field(max_length=32, index=True)
@@ -78,10 +88,11 @@ class ScoreBase(SQLModel):
default=0, sa_column=Column(BigInteger), exclude=True
)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
# optional
# TODO: current_user_attributes
position: int | None = Field(default=None) # multiplayer
# position: int | None = Field(default=None) # multiplayer
class Score(ScoreBase, table=True):
@@ -89,12 +100,11 @@ class Score(ScoreBase, table=True):
id: int | None = Field(
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
ForeignKey("lazer_users.id"),
index=True,
),
)
@@ -112,28 +122,13 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True)
# optional
beatmap: "Beatmap" = Relationship()
user: "User" = Relationship()
beatmap: Beatmap = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
clause = select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
)
if with_user:
return clause.options(
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
)
return clause
@staticmethod
def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -147,18 +142,7 @@ class Score(ScoreBase, table=True):
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
return (
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
)
)
return select(best).where(subq.c.rn == 1)
class ScoreResp(ScoreBase):
@@ -173,22 +157,23 @@ class ScoreResp(ScoreBase):
ruleset_id: int | None = None
beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None
user: APIUser | None = None
user: UserResp | None = None
statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
position: int | None = None
scores_around: "ScoreAround | None" = None
@classmethod
async def from_db(
cls, session: AsyncSession, score: Score, user: User | None = None
) -> "ScoreResp":
from app.utils import convert_db_user_to_api_user
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
s = cls.model_validate(score.model_dump())
assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap)
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(
score.beatmap.beatmapset, session=session, user=score.user
)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -220,163 +205,184 @@ class ScoreResp(ScoreBase):
s.maximum_statistics = {
HitResult.GREAT: score.beatmap.max_combo,
}
if user:
s.user = await convert_db_user_to_api_user(user)
s.user = await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
)
s.rank_global = (
await get_score_position_by_id(
session,
score.map_md5,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=user or score.user,
user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
score.map_md5,
score.beatmap_id,
score.id,
score.gamemode,
user or score.user,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
return s
class MultiplayerScores(RespWithCursor):
scores: list[ScoreResp] = Field(default_factory=list)
params: dict[str, Any] = Field(default_factory=dict)
class ScoreAround(SQLModel):
higher: MultiplayerScores | None = None
lower: MultiplayerScores | None = None
async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = (
func.row_number()
.over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc())
.over(
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
)
.label("rn")
)
subq = select(BestScore, rownum).subquery()
subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
return result.one_or_none()
async def _score_where(
type: LeaderboardType,
beatmap: int,
mode: GameMode,
mods: list[str] | None = None,
user: User | None = None,
) -> list[ColumnElement[bool]] | None:
wheres = [
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
]
if type == LeaderboardType.FRIENDS:
if user and user.is_supporter:
subq = (
select(DBRelationship.target_id)
.where(
DBRelationship.type == RelationshipType.FOLLOW,
DBRelationship.user_id == user.id,
)
.subquery()
)
wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
else:
return None
elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter:
wheres.append(
col(BestScore.user).has(col(User.country_code) == user.country_code)
)
else:
return None
elif type == LeaderboardType.TEAM:
if user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(
col(BestScore.user).has(
col(User.team_membership).has(TeamMember.team_id == team_id)
)
)
if mods:
if user and user.is_supporter:
wheres.append(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
) # pyright: ignore[reportArgumentType]
)
else:
return None
return wheres
async def get_leaderboard(
session: AsyncSession,
beatmap_md5: str,
beatmap: int,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
mods: list[str] | None = None,
user: User | None = None,
limit: int = 50,
) -> list[Score]:
scores = []
if type == LeaderboardType.GLOBAL:
query = (
select(Score)
.where(
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user and user.team_membership:
team_id = user.team_membership.team_id
query = (
select(Score)
.join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Score.user.team_membership).is_not(None),
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
) -> tuple[list[Score], Score | None]:
wheres = await _score_where(type, beatmap, mode, mods, user)
if wheres is None:
return [], None
query = (
select(BestScore)
.where(*wheres)
.limit(limit)
.order_by(col(BestScore.total_score).desc())
)
if mods:
query = query.params(w=json.dumps(mods))
scores = [s.score for s in await session.exec(query)]
user_score = None
if user:
user_score = (
await session.exec(
select(Score).where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
Score.user_id == user.id,
col(Score.passed).is_(True),
)
self_query = (
select(BestScore)
.where(BestScore.user_id == user.id)
.where(
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
)
).first()
.order_by(col(BestScore.total_score).desc())
.limit(1)
)
if mods:
self_query = self_query.where(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
)
).params(w=json.dumps(mods))
user_bs = (await session.exec(self_query)).first()
if user_bs:
user_score = user_bs.score
if user_score and user_score not in scores:
scores.append(user_score)
return scores
return scores, user_score
async def get_score_position_by_user(
session: AsyncSession,
beatmap_md5: str,
beatmap: int,
user: User,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
mods: list[str] | None = None,
) -> int:
where_clause = [
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user.is_supporter:
where_clause.append(Score.mods == mods)
else:
where_clause.append(false())
if type == LeaderboardType.FRIENDS and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user.team_membership:
team_id = user.team_membership.team_id
where_clause.append(
col(Score.user.team_membership).is_not(None),
)
where_clause.append(
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
)
wheres = await _score_where(type, beatmap, mode, mods, user=user)
if wheres is None:
return 0
rownum = (
func.row_number()
.over(
partition_by=Score.map_md5,
order_by=col(Score.total_score).desc(),
partition_by=col(BestScore.beatmap_id),
order_by=col(BestScore.total_score).desc(),
)
.label("row_number")
)
subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery()
stmt = select(subq.c.row_number).where(subq.c.user == user)
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
@@ -384,57 +390,26 @@ async def get_score_position_by_user(
async def get_score_position_by_id(
session: AsyncSession,
beatmap_md5: str,
beatmap: int,
score_id: int,
mode: GameMode,
user: User | None = None,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
mods: list[str] | None = None,
) -> int:
where_clause = [
Score.map_md5 == beatmap_md5,
Score.id == score_id,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user and user.is_supporter:
where_clause.append(Score.mods == mods)
elif mods:
where_clause.append(false())
wheres = await _score_where(type, beatmap, mode, mods, user=user)
if wheres is None:
return 0
rownum = (
func.row_number()
.over(
partition_by=[col(Score.user_id), col(Score.map_md5)],
order_by=col(Score.total_score).desc(),
partition_by=col(BestScore.beatmap_id),
order_by=col(BestScore.total_score).desc(),
)
.label("rownum")
.label("row_number")
)
subq = (
select(Score.user_id, Score.id, Score.total_score, rownum)
.join(Beatmap)
.where(*where_clause)
.subquery()
)
best_scores = aliased(subq)
overall_rank = (
func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
)
final_q = (
select(best_scores.c.id, overall_rank)
.select_from(best_scores)
.where(best_scores.c.rownum == 1)
.subquery()
)
stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
@@ -445,16 +420,38 @@ async def get_user_best_score_in_beatmap(
beatmap: int,
user: int,
mode: GameMode | None = None,
) -> Score | None:
) -> BestScore | None:
return (
await session.exec(
Score.select_clause(False)
select(BestScore)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
BestScore.gamemode == mode if mode is not None else true(),
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
)
.order_by(col(Score.total_score).desc())
.order_by(col(BestScore.total_score).desc())
)
).first()
# FIXME
async def get_user_best_score_with_mod_in_beatmap(
session: AsyncSession,
beatmap: int,
user: int,
mod: list[str],
mode: GameMode | None = None,
) -> BestScore | None:
return (
await session.exec(
select(BestScore)
.where(
BestScore.gamemode == mode if mode is not None else True,
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
# BestScore.mods == mod,
)
.order_by(col(BestScore.total_score).desc())
)
).first()
@@ -464,13 +461,13 @@ async def get_user_best_pp_in_beatmap(
beatmap: int,
user: int,
mode: GameMode,
) -> BestScore | None:
) -> PPBestScore | None:
return (
await session.exec(
select(BestScore).where(
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
BestScore.gamemode == mode,
select(PPBestScore).where(
PPBestScore.beatmap_id == beatmap,
PPBestScore.user_id == user,
PPBestScore.gamemode == mode,
)
)
).first()
@@ -480,12 +477,12 @@ async def get_user_best_pp(
session: AsyncSession,
user: int,
limit: int = 200,
) -> Sequence[BestScore]:
) -> Sequence[PPBestScore]:
return (
await session.exec(
select(BestScore)
.where(BestScore.user_id == user)
.order_by(col(BestScore.pp).desc())
select(PPBestScore)
.where(PPBestScore.user_id == user)
.order_by(col(PPBestScore.pp).desc())
.limit(limit)
)
).all()
@@ -494,27 +491,45 @@ async def get_user_best_pp(
async def process_user(
session: AsyncSession, user: User, score: Score, ranked: bool = False
):
assert user.id
assert score.id
mod_for_save = list({mod["acronym"] for mod in score.mods})
previous_score_best = await get_user_best_score_in_beatmap(
session, score.beatmap_id, user.id, score.gamemode
)
statistics = None
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
)
add_to_db = False
for i in user.lazer_statistics:
mouthly_playcount = (
await session.exec(
select(MonthlyPlaycounts).where(
MonthlyPlaycounts.user_id == user.id,
MonthlyPlaycounts.year == date.today().year,
MonthlyPlaycounts.month == date.today().month,
)
)
).first()
if mouthly_playcount is None:
mouthly_playcount = MonthlyPlaycounts(
user_id=user.id, year=date.today().year, month=date.today().month
)
add_to_db = True
statistics = None
for i in await user.awaitable_attrs.statistics:
if i.mode == score.gamemode.value:
statistics = i
break
if statistics is None:
statistics = LazerUserStatistics(
mode=score.gamemode.value,
user_id=user.id,
raise ValueError(
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
)
add_to_db = True
# pc, pt, tth, tts
statistics.total_score += score.total_score
difference = (
score.total_score - previous_score_best.total_score
if previous_score_best and previous_score_best.id != score.id
if previous_score_best
else score.total_score
)
if difference > 0 and score.passed and ranked:
@@ -541,11 +556,48 @@ async def process_user(
statistics.grade_sh -= 1
case Rank.A:
statistics.grade_a -= 1
else:
previous_score_best = BestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
score_id=score.id,
total_score=score.total_score,
rank=score.rank,
mods=mod_for_save,
)
session.add(previous_score_best)
statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.ranked_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
if score.passed and ranked:
if previous_score_best_mod is not None:
previous_score_best_mod.mods = mod_for_save
previous_score_best_mod.score_id = score.id
previous_score_best_mod.rank = score.rank
previous_score_best_mod.total_score = score.total_score
elif (
previous_score_best is not None and previous_score_best.score_id != score.id
):
session.add(
BestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
score_id=score.id,
total_score=score.total_score,
rank=score.rank,
mods=mod_for_save,
)
)
statistics.play_count += 1
mouthly_playcount.playcount += 1
statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
statistics.count_100 += score.n100 + score.nkatu
statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50
statistics.count_miss += score.nmiss
statistics.total_hits += (
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
)
@@ -563,11 +615,8 @@ async def process_user(
acc_sum = clamp(acc_sum, 0.0, 100.0)
statistics.pp = pp_sum
statistics.hit_accuracy = acc_sum
statistics.updated_at = datetime.now(UTC)
if add_to_db:
session.add(statistics)
session.add(mouthly_playcount)
await session.commit()
await session.refresh(user)
@@ -581,7 +630,10 @@ async def process_score(
fetcher: "Fetcher",
session: AsyncSession,
redis: Redis,
item_id: int | None = None,
room_id: int | None = None,
) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
score = Score(
accuracy=info.accuracy,
@@ -611,6 +663,8 @@ async def process_score(
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
playlist_item_id=item_id,
room_id=room_id,
)
if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
@@ -628,7 +682,7 @@ async def process_score(
)
if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id
best_score = BestScore(
best_score = PPBestScore(
user_id=user_id,
score_id=score.id,
beatmap_id=beatmap_id,
@@ -637,7 +691,7 @@ async def process_score(
acc=score.accuracy,
)
session.add(best_score)
session.delete(previous_pp_best) if previous_pp_best else None
await session.delete(previous_pp_best) if previous_pp_best else None
await session.commit()
await session.refresh(score)
await session.refresh(score_token)

View File

@@ -1,15 +1,16 @@
from datetime import datetime
from app.models.model import UTCBaseModel
from app.models.score import GameMode
from .beatmap import Beatmap
from .user import User
from .lazer_user import User
from sqlalchemy import Column, DateTime, Index
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
class ScoreTokenBase(SQLModel):
class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist
@@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True):
autoincrement=True,
),
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id")
user: "User" = Relationship()
beatmap: "Beatmap" = Relationship()
user: User = Relationship()
beatmap: Beatmap = Relationship()
class ScoreTokenResp(ScoreTokenBase):

View File

@@ -0,0 +1,95 @@
from typing import TYPE_CHECKING
from app.models.score import GameMode
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .lazer_user import User
class UserStatisticsBase(SQLModel):
mode: GameMode
count_100: int = Field(default=0, sa_column=Column(BigInteger))
count_300: int = Field(default=0, sa_column=Column(BigInteger))
count_50: int = Field(default=0, sa_column=Column(BigInteger))
count_miss: int = Field(default=0, sa_column=Column(BigInteger))
global_rank: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
pp: float = Field(default=0.0)
ranked_score: int = Field(default=0)
hit_accuracy: float = Field(default=0.00)
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_hits: int = Field(default=0, sa_column=Column(BigInteger))
maximum_combo: int = Field(default=0)
play_count: int = Field(default=0)
play_time: int = Field(default=0, sa_column=Column(BigInteger))
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=True)
class UserStatistics(UserStatisticsBase, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
user: "User" = Relationship(back_populates="statistics") # type: ignore[valid-type]
class UserStatisticsResp(UserStatisticsBase):
grade_counts: dict[str, int] = Field(
default_factory=lambda: {
"ss": 0,
"ssh": 0,
"s": 0,
"sh": 0,
"a": 0,
}
)
level: dict[str, int] = Field(
default_factory=lambda: {
"current": 1,
"progress": 0,
}
)
@classmethod
def from_db(cls, obj: UserStatistics) -> "UserStatisticsResp":
s = cls.model_validate(obj)
s.grade_counts = {
"ss": obj.grade_ss,
"ssh": obj.grade_ssh,
"s": obj.grade_s,
"sh": obj.grade_sh,
"a": obj.grade_a,
}
s.level = {
"current": obj.level_current,
"progress": obj.level_progress,
}
return s

View File

@@ -1,14 +1,16 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
from .lazer_user import User
class Team(SQLModel, table=True):
class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -22,15 +24,19 @@ class Team(SQLModel, table=True):
members: list["TeamMember"] = Relationship(back_populates="team")
class TeamMember(SQLModel, table=True):
class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="team_membership")
team: "Team" = Relationship(back_populates="members")
user: "User" = Relationship(
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)

View File

@@ -1,527 +0,0 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from .legacy import LegacyUserStatistics
from .team import TeamMember
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
from sqlalchemy.dialects.mysql import VARCHAR
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键
id: int = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
)
# 基本信息(匹配 migrations_old 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名
safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名
email: str = Field(max_length=254, unique=True, index=True)
priv: int = Field(default=1) # 权限
pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码
country: str = Field(default="CN", max_length=2) # 国家代码
# 状态和时间
silence_end: int = Field(default=0)
donor_end: int = Field(default=0)
creation_time: int = Field(default=0) # Unix 时间戳
latest_activity: int = Field(default=0) # Unix 时间戳
# 游戏相关
preferred_mode: int = Field(default=0) # 偏好游戏模式
play_style: int = Field(default=0) # 游戏风格
# 扩展信息
clan_id: int = Field(default=0)
clan_priv: int = Field(default=0)
custom_badge_name: str | None = Field(default=None, max_length=16)
custom_badge_icon: str | None = Field(default=None, max_length=64)
userpage_content: str | None = Field(default=None, max_length=2048)
api_key: str | None = Field(default=None, max_length=36, unique=True)
# 虚拟字段用于兼容性
@property
def username(self):
return self.name
@property
def country_code(self):
return self.country
@property
def join_date(self):
creation_time = getattr(self, "creation_time", 0)
return (
datetime.fromtimestamp(creation_time)
if creation_time > 0
else datetime.utcnow()
)
@property
def last_visit(self):
latest_activity = getattr(self, "latest_activity", 0)
return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
@property
def is_supporter(self):
return self.lazer_profile.is_supporter if self.lazer_profile else False
# 关联关系
lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user")
lazer_achievements: list["LazerUserAchievement"] = Relationship(
back_populates="user"
)
lazer_profile_sections: list["LazerUserProfileSections"] = Relationship(
back_populates="user"
)
statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
back_populates="user"
)
rank_history: list["RankHistory"] = Relationship(back_populates="user")
avatar: Optional["UserAvatar"] = Relationship(back_populates="user")
active_banners: list["LazerUserBanners"] = Relationship(back_populates="user")
lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user")
lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship(
back_populates="user"
)
lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship(
back_populates="user"
)
lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship(
back_populates="user"
)
@classmethod
def all_select_option(cls):
return (
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
@classmethod
def all_select_clause(cls):
return select(cls).options(*cls.all_select_option())
# ============================================
# Lazer API 专用表模型
# ============================================
class LazerUserProfile(SQLModel, table=True):
__tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 基本状态字段
is_active: bool = Field(default=True)
is_bot: bool = Field(default=False)
is_deleted: bool = Field(default=False)
is_online: bool = Field(default=True)
is_supporter: bool = Field(default=False)
is_restricted: bool = Field(default=False)
session_verified: bool = Field(default=False)
has_supported: bool = Field(default=False)
pm_friends_only: bool = Field(default=False)
# 基本资料字段
default_group: str = Field(default="default", max_length=50)
last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime))
join_date: datetime | None = Field(default=None, sa_column=Column(DateTime))
profile_colour: str | None = Field(default=None, max_length=7)
profile_hue: int | None = Field(default=None)
# 社交媒体和个人资料字段
avatar_url: str | None = Field(default=None, max_length=500)
cover_url: str | None = Field(default=None, max_length=500)
discord: str | None = Field(default=None, max_length=100)
twitter: str | None = Field(default=None, max_length=100)
website: str | None = Field(default=None, max_length=500)
title: str | None = Field(default=None, max_length=100)
title_url: str | None = Field(default=None, max_length=500)
interests: str | None = Field(default=None, sa_column=Column(Text))
location: str | None = Field(default=None, max_length=100)
occupation: str | None = Field(default=None) # 职业字段,默认为 None
# 游戏相关字段
playmode: str = Field(default="osu", max_length=10)
support_level: int = Field(default=0)
max_blocks: int = Field(default=100)
max_friends: int = Field(default=500)
post_count: int = Field(default=0)
# 页面内容
page_html: str | None = Field(default=None, sa_column=Column(Text))
page_raw: str | None = Field(default=None, sa_column=Column(Text))
profile_order: str = Field(
default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu"
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_profile")
class LazerUserProfileSections(SQLModel, table=True):
__tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
section_name: str = Field(sa_column=Column(VARCHAR(50)))
display_order: int | None = Field(default=None)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_profile_sections")
class LazerUserCountry(SQLModel, table=True):
__tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
code: str = Field(max_length=2)
name: str = Field(max_length=100)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class LazerUserKudosu(SQLModel, table=True):
__tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
available: int = Field(default=0)
total: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class LazerUserCounts(SQLModel, table=True):
__tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 统计计数字段
beatmap_playcounts_count: int = Field(default=0)
comments_count: int = Field(default=0)
favourite_beatmapset_count: int = Field(default=0)
follower_count: int = Field(default=0)
graveyard_beatmapset_count: int = Field(default=0)
guest_beatmapset_count: int = Field(default=0)
loved_beatmapset_count: int = Field(default=0)
mapping_follower_count: int = Field(default=0)
nominated_beatmapset_count: int = Field(default=0)
pending_beatmapset_count: int = Field(default=0)
ranked_beatmapset_count: int = Field(default=0)
ranked_and_approved_beatmapset_count: int = Field(default=0)
unranked_beatmapset_count: int = Field(default=0)
scores_best_count: int = Field(default=0)
scores_first_count: int = Field(default=0)
scores_pinned_count: int = Field(default=0)
scores_recent_count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_counts")
class LazerUserStatistics(SQLModel, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
mode: str = Field(default="osu", max_length=10, primary_key=True)
# 基本命中统计
count_100: int = Field(default=0)
count_300: int = Field(default=0)
count_50: int = Field(default=0)
count_miss: int = Field(default=0)
# 等级信息
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
# 排名信息
global_rank: int | None = Field(default=None)
global_rank_exp: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
# PP 和分数
pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
ranked_score: int = Field(default=0, sa_column=Column(BigInteger))
hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2)))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_hits: int = Field(default=0, sa_column=Column(BigInteger))
maximum_combo: int = Field(default=0)
# 游戏统计
play_count: int = Field(default=0)
play_time: int = Field(default=0) # 秒
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=False)
# 成绩等级计数
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
# 最高排名记录
rank_highest: int | None = Field(default=None)
rank_highest_updated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_statistics")
class LazerUserBanners(SQLModel, table=True):
__tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
tournament_id: int
image_url: str = Field(sa_column=Column(VARCHAR(500)))
is_active: bool | None = Field(default=None)
# 修正user关系的back_populates值
user: "User" = Relationship(back_populates="active_banners")
class LazerUserAchievement(SQLModel, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
achievement_id: int
achieved_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_achievements")
class LazerUserBadge(SQLModel, table=True):
__tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
badge_id: int
awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
description: str | None = Field(default=None, sa_column=Column(Text))
image_url: str | None = Field(default=None, max_length=500)
url: str | None = Field(default=None, max_length=500)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_badges")
class LazerUserMonthlyPlaycounts(SQLModel, table=True):
__tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
play_count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_monthly_playcounts")
class LazerUserPreviousUsername(SQLModel, table=True):
__tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
username: str = Field(max_length=32)
changed_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_previous_usernames")
class LazerUserReplaysWatched(SQLModel, table=True):
__tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_replays_watched")
# 类型转换用的 UserAchievement不是 SQLAlchemy 模型)
@dataclass
class UserAchievement:
achieved_at: datetime
achievement_id: int
class DailyChallengeStats(SQLModel, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
)
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0)
weekly_streak_best: int = Field(default=0)
weekly_streak_current: int = Field(default=0)
user: "User" = Relationship(back_populates="daily_challenge_stats")
class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10)
rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
date_recorded: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="rank_history")
class UserAvatar(SQLModel, table=True):
__tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
filename: str = Field(max_length=255)
original_filename: str = Field(max_length=255)
file_size: int
mime_type: str = Field(max_length=100)
is_active: bool = Field(default=True)
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
r2_original_url: str | None = Field(default=None, max_length=500)
r2_game_url: str | None = Field(default=None, max_length=500)
user: "User" = Relationship(back_populates="avatar")

View File

@@ -0,0 +1,45 @@
from datetime import UTC, datetime
from enum import Enum
from app.models.model import UTCBaseModel
from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, SQLModel
class UserAccountHistoryType(str, Enum):
NOTE = "note"
RESTRICTION = "restriction"
SLIENCE = "silence"
TOURNAMENT_BAN = "tournament_ban"
class UserAccountHistoryBase(SQLModel, UTCBaseModel):
description: str | None = None
length: int
permanent: bool = False
timestamp: datetime = Field(default=datetime.now(UTC))
type: UserAccountHistoryType
class UserAccountHistory(UserAccountHistoryBase, table=True):
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
sa_column=Column(
Integer,
autoincrement=True,
index=True,
primary_key=True,
)
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
class UserAccountHistoryResp(UserAccountHistoryBase):
id: int | None = None
@classmethod
def from_db(cls, db_model: UserAccountHistory) -> "UserAccountHistoryResp":
return cls.model_validate(db_model)

View File

@@ -5,15 +5,11 @@ import json
from app.config import settings
from pydantic import BaseModel
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
except ImportError:
redis = None
def json_serializer(value):
if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接
if redis:
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
else:
redis_client = None
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
# 数据库依赖

View File

@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None
def get_fetcher() -> Fetcher:
async def get_fetcher() -> Fetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL,
)
redis = get_redis()
if redis:
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
if access_token:
fetcher.access_token = str(access_token)
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
if access_token:
fetcher.access_token = str(access_token)
refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
return fetcher

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
from app.auth import get_token_by_access_token
from app.database import (
User as DBUser,
)
from app.database import User
from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
@@ -17,7 +16,7 @@ security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> DBUser:
) -> User:
"""获取当前认证用户"""
token = credentials.credentials
@@ -27,13 +26,9 @@ async def get_current_user(
return user
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = (
await db.exec(
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
)
).first()
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user

View File

@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
if redis:
redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
async def refresh_access_token(self) -> None:
async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
if redis:
redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)

View File

@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient
from loguru import logger
import redis
import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher):
@@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher):
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
if await redis.exists(f"beatmap:{beatmap_id}:raw"):
return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw

View File

@@ -42,11 +42,12 @@ class Language(IntEnum):
KOREAN = 6
FRENCH = 7
GERMAN = 8
ITALIAN = 9
SPANISH = 10
RUSSIAN = 11
POLISH = 12
OTHER = 13
SWEDISH = 9
ITALIAN = 10
SPANISH = 11
RUSSIAN = 12
POLISH = 13
OTHER = 14
class BeatmapAttributes(BaseModel):

View File

@@ -1,114 +1,85 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any, Literal
from typing import ClassVar, Literal
from app.models.signalr import UserState
from app.models.signalr import SignalRUnionMessage, UserState
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel
class _UserActivity(BaseModel):
model_config = ConfigDict(serialize_by_alias=True)
type: Literal[
"ChoosingBeatmap",
"InSoloGame",
"WatchingReplay",
"SpectatingUser",
"SearchingForLobby",
"InLobby",
"InMultiplayerGame",
"SpectatingMultiplayerGame",
"InPlaylistGame",
"EditingBeatmap",
"ModdingBeatmap",
"TestingBeatmap",
"InDailyChallengeLobby",
"PlayingDailyChallenge",
] = Field(alias="$dtype")
value: Any | None = Field(alias="$value")
class _UserActivity(SignalRUnionMessage): ...
class ChoosingBeatmap(_UserActivity):
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
class InGameValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
ruleset_id: int = Field(alias="RulesetID")
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
union_type: ClassVar[Literal[11]] = 11
class _InGame(_UserActivity):
value: InGameValue = Field(alias="$value")
beatmap_id: int
beatmap_display_title: str
ruleset_id: int
ruleset_playing_verb: str
class InSoloGame(_InGame):
type: Literal["InSoloGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[12]] = 12
class InMultiplayerGame(_InGame):
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[23]] = 23
class SpectatingMultiplayerGame(_InGame):
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[24]] = 24
class InPlaylistGame(_InGame):
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[31]] = 31
class EditingBeatmapValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class PlayingDailyChallenge(_InGame):
union_type: ClassVar[Literal[52]] = 52
class EditingBeatmap(_UserActivity):
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
value: EditingBeatmapValue = Field(alias="$value")
union_type: ClassVar[Literal[41]] = 41
beatmap_id: int
beatmap_display_title: str
class TestingBeatmap(_UserActivity):
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
class TestingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[43]] = 43
class ModdingBeatmap(_UserActivity):
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
class WatchingReplayValue(BaseModel):
score_id: int = Field(alias="ScoreID")
player_name: str = Field(alias="PlayerName")
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class ModdingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[42]] = 42
class WatchingReplay(_UserActivity):
type: Literal["WatchingReplay"] = Field(alias="$dtype")
value: int | None = Field(alias="$value") # Replay ID
union_type: ClassVar[Literal[13]] = 13
score_id: int
player_name: str
beatmap_id: int
beatmap_display_title: str
class SpectatingUser(WatchingReplay):
type: Literal["SpectatingUser"] = Field(alias="$dtype")
union_type: ClassVar[Literal[14]] = 14
class SearchingForLobby(_UserActivity):
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
class InLobbyValue(BaseModel):
room_id: int = Field(alias="RoomID")
room_name: str = Field(alias="RoomName")
union_type: ClassVar[Literal[21]] = 21
class InLobby(_UserActivity):
type: Literal["InLobby"] = "InLobby"
union_type: ClassVar[Literal[22]] = 22
room_id: int
room_name: str
class InDailyChallengeLobby(_UserActivity):
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
union_type: ClassVar[Literal[51]] = 51
UserActivity = (
@@ -128,23 +99,25 @@ UserActivity = (
)
class MetadataClientState(UserState):
user_activity: UserActivity | None = None
status: OnlineStatus | None = None
class UserPresence(BaseModel):
activity: UserActivity | None = None
def to_dict(self) -> dict[str, Any] | None:
if self.status is None or self.status == OnlineStatus.OFFLINE:
return None
dumped = self.model_dump(by_alias=True, exclude_none=True)
return {
"Activity": dumped.get("user_activity"),
"Status": dumped.get("status"),
}
status: OnlineStatus | None = None
@property
def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE
@property
def for_push(self) -> "UserPresence | None":
return UserPresence(
activity=self.activity,
status=self.status,
)
class MetadataClientState(UserPresence, UserState): ...
class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身

22
app/models/model.py Normal file
View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from datetime import UTC, datetime
from pydantic import BaseModel, field_serializer
class UTCBaseModel(BaseModel):
@field_serializer("*", when_used="json")
def serialize_datetime(self, v, _info):
if isinstance(v, datetime):
if v.tzinfo is None:
v = v.replace(tzinfo=UTC)
return v.astimezone(UTC).isoformat()
return v
Cursor = dict[str, int]
class RespWithCursor(BaseModel):
cursor: Cursor | None = None

View File

@@ -8,7 +8,7 @@ from app.path import STATIC_DIR
class APIMod(TypedDict):
acronym: str
settings: NotRequired[dict[str, bool | float | str]]
settings: NotRequired[dict[str, bool | float | str | int]]
# https://github.com/ppy/osu-api/wiki#mods

View File

@@ -0,0 +1,925 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import IntEnum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
TypedDict,
cast,
override,
)
from app.database.beatmap import Beatmap
from app.dependencies.database import engine
from app.exception import InvokeException
from .mods import APIMod
from .room import (
DownloadState,
MatchType,
MultiplayerRoomState,
MultiplayerUserState,
QueueMode,
RoomCategory,
RoomStatus,
)
from .signalr import (
SignalRMeta,
SignalRUnionMessage,
UserState,
)
from pydantic import BaseModel, Field
from sqlalchemy import update
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50
PER_USER_LIMIT = 3
class MultiplayerClientState(UserState):
room_id: int = 0
class MultiplayerRoomSettings(BaseModel):
name: str = "Unnamed Room"
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
password: str = ""
match_type: MatchType = MatchType.HEAD_TO_HEAD
queue_mode: QueueMode = QueueMode.HOST_ONLY
auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False
@property
def auto_start_enabled(self) -> bool:
return self.auto_start_duration != timedelta(seconds=0)
class BeatmapAvailability(BaseModel):
state: DownloadState = DownloadState.UNKNOWN
download_progress: float | None = None
class _MatchUserState(SignalRUnionMessage): ...
class TeamVersusUserState(_MatchUserState):
team_id: int
union_type: ClassVar[Literal[0]] = 0
MatchUserState = TeamVersusUserState
class _MatchRoomState(SignalRUnionMessage): ...
class MultiplayerTeam(BaseModel):
id: int
name: str
class TeamVersusRoomState(_MatchRoomState):
teams: list[MultiplayerTeam] = Field(
default_factory=lambda: [
MultiplayerTeam(id=0, name="Team Red"),
MultiplayerTeam(id=1, name="Team Blue"),
]
)
union_type: ClassVar[Literal[0]] = 0
MatchRoomState = TeamVersusRoomState
class PlaylistItem(BaseModel):
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
owner_id: int
beatmap_id: int
beatmap_checksum: str
ruleset_id: int
required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool
playlist_order: int
played_at: datetime | None = None
star_rating: float
freestyle: bool
def _get_api_mods(self):
from app.models.mods import API_MODS, init_mods
if not API_MODS:
init_mods()
return API_MODS
def _validate_mod_for_ruleset(
self, mod: APIMod, ruleset_key: int, context: str = "mod"
) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset
if (
typed_ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[typed_ruleset_key]
):
raise InvokeException(
f"{context} {mod['acronym']} is invalid for this ruleset"
)
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not playable by users"
)
if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not valid for multiplayer"
)
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
for i, mod1 in enumerate(mods):
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
if mod1_settings:
incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible:
raise InvokeException(
f"Mods {mod1['acronym']} and "
f"{mod2['acronym']} are incompatible"
)
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
for req_mod in self.required_mods:
req_acronym = req_mod["acronym"]
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
if req_settings:
incompatible = set(req_settings.get("IncompatibleMods", []))
conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed)
raise InvokeException(
f"Required mod {req_acronym} conflicts with "
f"allowed mods: {conflict_list}"
)
def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
# Validate required mods
for mod in self.required_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
# Validate allowed mods
for mod in self.allowed_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
# Check internal compatibility of required mods
self._check_mod_compatibility(self.required_mods, ruleset_key)
# Check compatibility between required and allowed mods
self._check_required_allowed_compatibility(ruleset_key)
def validate_user_mods(
self,
user: "MultiplayerRoomUser",
proposed_mods: list[APIMod],
) -> tuple[bool, list[APIMod]]:
"""
Validates user mods against playlist item rules and returns valid mods.
Returns (is_valid, valid_mods).
"""
from typing import Literal, cast
API_MODS = self._get_api_mods()
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
valid_mods = []
all_proposed_valid = True
# Check if mods are valid for the ruleset
for mod in proposed_mods:
if (
ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[ruleset_key]
):
all_proposed_valid = False
continue
valid_mods.append(mod)
# Check mod compatibility within user mods
incompatible_mods = set()
final_valid_mods = []
for mod in valid_mods:
if mod["acronym"] in incompatible_mods:
all_proposed_valid = False
continue
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
if setting_mods:
incompatible_mods.update(setting_mods["IncompatibleMods"])
final_valid_mods.append(mod)
# If not freestyle, check against allowed mods
if not self.freestyle:
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
filtered_valid_mods = []
for mod in final_valid_mods:
if mod["acronym"] not in allowed_acronyms:
all_proposed_valid = False
else:
filtered_valid_mods.append(mod)
final_valid_mods = filtered_valid_mods
# Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = {
mod["acronym"] for mod in final_valid_mods
} | required_mod_acronyms
# Check for incompatibility between required and user mods
filtered_valid_mods = []
for mod in final_valid_mods:
mod_acronym = mod["acronym"]
is_compatible = True
for other_acronym in all_mod_acronyms:
if other_acronym == mod_acronym:
continue
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
is_compatible = False
all_proposed_valid = False
break
if is_compatible:
filtered_valid_mods.append(mod)
return all_proposed_valid, filtered_valid_mods
def clone(self) -> "PlaylistItem":
copy = self.model_copy()
copy.required_mods = list(self.required_mods)
copy.allowed_mods = list(self.allowed_mods)
copy.expired = False
copy.played_at = None
return copy
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
time_remaining: timedelta
is_exclusive: Annotated[
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
class MatchStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[0]] = 0
class ForceGameplayStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[1]] = 1
class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = (
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
)
class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, download_progress=None
)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle
beatmap_id: int | None = None # freestyle
class MultiplayerRoom(BaseModel):
room_id: int
state: MultiplayerRoomState
settings: MultiplayerRoomSettings
users: list[MultiplayerRoomUser] = Field(default_factory=list)
host: MultiplayerRoomUser | None = None
match_state: MatchRoomState | None = None
playlist: list[PlaylistItem] = Field(default_factory=list)
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
channel_id: int
@classmethod
def from_db(cls, room) -> "MultiplayerRoom":
"""
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
"""
# 用户列表
users = [MultiplayerRoomUser(user_id=room.host_id)]
host_user = MultiplayerRoomUser(user_id=room.host_id)
# playlist 转换
playlist = []
if hasattr(room, "playlist"):
for item in room.playlist:
playlist.append(
PlaylistItem(
id=item.id,
owner_id=item.owner_id,
beatmap_id=item.beatmap_id,
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
ruleset_id=item.ruleset_id,
required_mods=item.required_mods,
allowed_mods=item.allowed_mods,
expired=item.expired,
playlist_order=item.playlist_order,
played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating
if item.beatmap is not None
else 0.0,
freestyle=item.freestyle,
)
)
return cls(
room_id=room.id,
state=getattr(room, "state", MultiplayerRoomState.OPEN),
settings=MultiplayerRoomSettings(
name=room.name,
playlist_item_id=playlist[0].id if playlist else 0,
password=getattr(room, "password", ""),
match_type=room.type,
queue_mode=room.queue_mode,
auto_start_duration=timedelta(seconds=room.auto_start_duration),
auto_skip=room.auto_skip,
),
users=users,
host=host_user,
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=getattr(room, "channel_id", 0),
)
class MultiplayerQueue:
def __init__(self, room: "ServerMultiplayerRoom"):
self.server_room = room
self.current_index = 0
@property
def hub(self) -> "MultiplayerHub":
return self.server_room.hub
@property
def upcoming_items(self):
return sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda i: i.playlist_order,
)
@property
def room(self):
return self.server_room.room
async def update_order(self):
from app.database import Playlist
match self.room.settings.queue_mode:
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
ordered_active_items = []
is_first_set = True
first_set_order_by_user_id = {}
active_items = [item for item in self.room.playlist if not item.expired]
active_items.sort(key=lambda x: x.id)
user_item_groups = {}
for item in active_items:
if item.owner_id not in user_item_groups:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max(
(len(items) for items in user_item_groups.values()), default=0
)
for i in range(max_items):
current_set = []
for user_id, items in user_item_groups.items():
if i < len(items):
current_set.append(items[i])
if is_first_set:
current_set.sort(
key=lambda item: (item.playlist_order, item.id)
)
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx
for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(
key=lambda item: first_set_order_by_user_id.get(
item.owner_id, 0
)
)
ordered_active_items.extend(current_set)
is_first_set = False
for idx, item in enumerate(ordered_active_items):
item.playlist_order = idx
case _:
ordered_active_items = sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda x: x.id,
)
async with AsyncSession(engine) as session:
for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(
self.server_room, item, beatmap_changed=False
)
async def update_current_item(self):
upcoming_items = self.upcoming_items
next_item = (
upcoming_items[0]
if upcoming_items
else max(
self.room.playlist,
key=lambda i: i.played_at or datetime.min,
)
)
self.current_index = self.room.playlist.index(next_item)
last_id = self.room.settings.playlist_item_id
self.room.settings.playlist_item_id = next_item.id
if last_id != next_item.id:
await self.hub.setting_changed(self.server_room, True)
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
is_host = self.room.host and self.room.host.user_id == user.user_id
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if (
len([True for u in self.room.playlist if u.owner_id == user.user_id])
>= limit
):
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
async with session:
beatmap = await session.get(Beatmap, item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = float(
beatmap.difficulty_rating
) # FIXME: beatmap use decimal
await Playlist.add_to_db(item, self.room.room_id, session)
self.room.playlist.append(item)
await self.hub.playlist_added(self.server_room, item)
await self.update_order()
await self.update_current_item()
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
async with session:
beatmap = await session.get(Beatmap, item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next(
(i for i in self.room.playlist if i.id == item.id), None
)
if existing_item is None:
raise InvokeException(
"Attempted to change an item that doesn't exist"
)
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to change an item which is not owned by the user"
)
if existing_item.expired:
raise InvokeException(
"Attempted to change an item which has already been played"
)
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = float(beatmap.difficulty_rating)
item.playlist_order = existing_item.playlist_order
await Playlist.update(item, self.room.room_id, session)
# Update item in playlist
for idx, playlist_item in enumerate(self.room.playlist):
if playlist_item.id == item.id:
self.room.playlist[idx] = item
break
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum
!= existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
from app.database import Playlist
item = next(
(i for i in self.room.playlist if i.id == playlist_item_id),
None,
)
if item is None:
raise InvokeException("Item does not exist in the room")
# Check if it's the only item and current item
if item == self.current_item:
upcoming_items = [i for i in self.room.playlist if not i.expired]
if len(upcoming_items) == 1:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to remove an item which is not owned by the user"
)
if item.expired:
raise InvokeException(
"Attempted to remove an item which has already been played"
)
async with AsyncSession(engine) as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
self.room.playlist.remove(item)
self.current_index = self.room.playlist.index(self.upcoming_items[0])
await self.update_order()
await self.update_current_item()
await self.hub.playlist_removed(self.server_room, item.id)
async def finish_current_item(self):
from app.database import Playlist
async with AsyncSession(engine) as session:
played_at = datetime.now(UTC)
await session.execute(
update(Playlist)
.where(
col(Playlist.id) == self.current_item.id,
col(Playlist.room_id) == self.room.room_id,
)
.values(expired=True, played_at=played_at)
)
self.room.playlist[self.current_index].expired = True
self.room.playlist[self.current_index].played_at = played_at
await self.hub.playlist_changed(self.server_room, self.current_item, True)
await self.update_order()
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
async def update_queue_mode(self):
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_order()
await self.update_current_item()
@property
def current_item(self):
return self.room.playlist[self.current_index]
@dataclass
class CountdownInfo:
countdown: MultiplayerCountdown
duration: timedelta
task: asyncio.Task | None = None
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.time_remaining
if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0)
)
class _MatchRequest(SignalRUnionMessage): ...
class ChangeTeamRequest(_MatchRequest):
union_type: ClassVar[Literal[0]] = 0
team_id: int
class StartMatchCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[1]] = 1
duration: timedelta
class StopCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[2]] = 2
id: int
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
class MatchTypeHandler(ABC):
def __init__(self, room: "ServerMultiplayerRoom"):
self.room = room
self.hub = room.hub
@abstractmethod
async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
@abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ...
@abstractmethod
def get_details(self) -> MatchStartedEventDetail: ...
class HeadToHeadHandler(MatchTypeHandler):
@override
async def handle_join(self, user: MultiplayerRoomUser):
if user.match_state is not None:
user.match_state = None
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
return detail
class TeamVersusHandler(MatchTypeHandler):
@override
def __init__(self, room: "ServerMultiplayerRoom"):
super().__init__(room)
self.state = TeamVersusRoomState()
room.room.match_state = self.state
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
self.hub.tasks.add(task)
task.add_done_callback(self.hub.tasks.discard)
def _get_best_available_team(self) -> int:
for team in self.state.teams:
if all(
(
user.match_state is None
or not isinstance(user.match_state, TeamVersusUserState)
or user.match_state.team_id != team.id
)
for user in self.room.room.users
):
return team.id
from collections import defaultdict
team_counts = defaultdict(int)
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
team_counts[user.match_state.team_id] += 1
if team_counts:
min_count = min(team_counts.values())
for team_id, count in team_counts.items():
if count == min_count:
return team_id
return self.state.teams[0].id if self.state.teams else 0
@override
async def handle_join(self, user: MultiplayerRoomUser):
best_team_id = self._get_best_available_team()
user.match_state = TeamVersusUserState(team_id=best_team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
if not isinstance(request, ChangeTeamRequest):
return
if request.team_id not in [team.id for team in self.state.teams]:
raise InvokeException("Invalid team ID")
user.match_state = TeamVersusUserState(team_id=request.team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail
MATCH_TYPE_HANDLERS = {
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
MatchType.TEAM_VERSUS: TeamVersusHandler,
}
@dataclass
class ServerMultiplayerRoom:
room: MultiplayerRoom
category: RoomCategory
status: RoomStatus
start_at: datetime
hub: "MultiplayerHub"
match_type_handler: MatchTypeHandler
queue: MultiplayerQueue
_next_countdown_id: int
_countdown_id_lock: asyncio.Lock
_tracked_countdown: dict[int, CountdownInfo]
def __init__(
self,
room: MultiplayerRoom,
category: RoomCategory,
start_at: datetime,
hub: "MultiplayerHub",
):
self.room = room
self.category = category
self.status = RoomStatus.IDLE
self.start_at = start_at
self.hub = hub
self.queue = MultiplayerQueue(self)
self._next_countdown_id = 0
self._countdown_id_lock = asyncio.Lock()
self._tracked_countdown = {}
async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
self
)
for i in self.room.users:
await self.match_type_handler.handle_join(i)
async def get_next_countdown_id(self) -> int:
async with self._countdown_id_lock:
self._next_countdown_id += 1
return self._next_countdown_id
async def start_countdown(
self,
countdown: MultiplayerCountdown,
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
):
async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds())
if on_complete is not None:
await on_complete(self)
await self.stop_countdown(countdown)
if countdown.is_exclusive:
await self.stop_all_countdowns()
countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(
self, CountdownStartedEvent(countdown=info.countdown)
)
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):
info = self._tracked_countdown.get(countdown.id)
if info is None:
return
del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
if info.task is not None and not info.task.done():
info.task.cancel()
async def stop_all_countdowns(self):
for countdown in list(self._tracked_countdown.values()):
await self.stop_countdown(countdown.countdown)
self._tracked_countdown.clear()
self.room.active_countdowns.clear()
class _MatchServerEvent(SignalRUnionMessage): ...
class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown
union_type: ClassVar[Literal[0]] = 0
class CountdownStoppedEvent(_MatchServerEvent):
id: int
union_type: ClassVar[Literal[1]] = 1
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
class GameplayAbortReason(IntEnum):
LOAD_TOOK_TOO_LONG = 0
HOST_ABORTED = 1
class MatchStartedEventDetail(TypedDict):
room_type: Literal["playlists", "head_to_head", "team_versus"]
team: dict[int, Literal["blue", "red"]] | None

View File

@@ -1,7 +1,6 @@
# OAuth 相关模型
from __future__ import annotations
from typing import List
from pydantic import BaseModel
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
class RegistrationErrorResponse(BaseModel):
"""注册错误响应模型"""
form_error: dict
class UserRegistrationErrors(BaseModel):
"""用户注册错误模型"""
username: List[str] = []
user_email: List[str] = []
password: List[str] = []
username: list[str] = []
user_email: list[str] = []
password: list[str] = []
class RegistrationRequestErrors(BaseModel):
"""注册请求错误模型"""
message: str | None = None
redirect: str | None = None
user: UserRegistrationErrors | None = None

View File

@@ -1,17 +1,8 @@
from __future__ import annotations
from datetime import datetime, timedelta
from enum import Enum
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.user import User as DBUser
from app.fetcher import Fetcher
from app.models.mods import APIMod
from app.models.user import User
from app.utils import convert_db_user_to_api_user
from pydantic import BaseModel, Field
from sqlmodel.ext.asyncio.session import AsyncSession
from pydantic import BaseModel
class RoomCategory(str, Enum):
@@ -62,53 +53,24 @@ class MultiplayerUserState(str, Enum):
RESULTS = "results"
SPECTATING = "spectating"
@property
def is_playing(self) -> bool:
return self in {
self.WAITING_FOR_LOAD,
self.PLAYING,
self.READY_FOR_GAMEPLAY,
self.LOADED,
}
class DownloadState(str, Enum):
UNKOWN = "unkown"
UNKNOWN = "unknown"
NOT_DOWNLOADED = "not_downloaded"
DOWNLOADING = "downloading"
IMPORTING = "importing"
LOCALLY_AVAILABLE = "locally_available"
class PlaylistItem(BaseModel):
id: int
owner_id: int
ruleset_id: int
expired: bool
playlist_order: int | None
played_at: datetime | None
allowed_mods: list[APIMod] = []
required_mods: list[APIMod] = []
beatmap_id: int
beatmap: BeatmapResp | None
freestyle: bool
class Config:
exclude_none = True
@classmethod
async def from_mpListItem(
cls, item: MultiPlayerListItem, db: AsyncSession, fetcher: Fetcher
):
s = cls.model_validate(item.model_dump())
s.id = item.id
s.owner_id = item.OwnerID
s.ruleset_id = item.RulesetID
s.expired = item.Expired
s.playlist_order = item.PlaylistOrder
s.played_at = item.PlayedAt
s.required_mods = item.RequierdMods
s.allowed_mods = item.AllowedMods
s.freestyle = item.Freestyle
cur_beatmap = await Beatmap.get_or_fetch(
db, fetcher=fetcher, bid=item.BeatmapID
)
s.beatmap = BeatmapResp.from_db(cur_beatmap)
s.beatmap_id = item.BeatmapID
return s
class RoomPlaylistItemStats(BaseModel):
count_active: int
count_total: int
@@ -120,269 +82,7 @@ class RoomDifficultyRange(BaseModel):
max: float
class ItemAttemptsCount(BaseModel):
id: int
attempts: int
passed: bool
class PlaylistAggregateScore(BaseModel):
playlist_item_attempts: list[ItemAttemptsCount]
class MultiplayerRoomSettings(BaseModel):
Name: str = "Unnamed Room"
PlaylistItemId: int
Password: str = ""
MatchType: MatchType
QueueMode: QueueMode
AutoStartDuration: timedelta
AutoSkip: bool
@classmethod
def from_apiRoom(cls, room: Room):
s = cls.model_validate(room.model_dump())
s.Name = room.name
s.Password = room.password if room.password is not None else ""
s.MatchType = room.type
s.QueueMode = room.queue_mode
s.AutoStartDuration = timedelta(seconds=room.auto_start_duration)
s.AutoSkip = room.auto_skip
return s
class BeatmapAvailability(BaseModel):
State: DownloadState
DownloadProgress: float | None
class MatchUserState(BaseModel):
class Config:
extra = "allow"
class TeamVersusState(MatchUserState):
TeamId: int
MatchUserStateType = TeamVersusState | MatchUserState
class MultiplayerRoomUser(BaseModel):
UserID: int
State: MultiplayerUserState = MultiplayerUserState.IDLE
BeatmapAvailability: BeatmapAvailability
Mods: list[APIMod] = []
MatchUserState: MatchUserStateType | None
RulesetId: int | None
BeatmapId: int | None
User: User | None
@classmethod
async def from_id(cls, id: int, db: AsyncSession):
actualUser = (
await db.exec(
DBUser.all_select_clause().where(
DBUser.id == id,
)
)
).first()
user = (
await convert_db_user_to_api_user(actualUser)
if actualUser is not None
else None
)
return MultiplayerRoomUser(
UserID=id,
MatchUserState=None,
BeatmapAvailability=BeatmapAvailability(
State=DownloadState.UNKOWN, DownloadProgress=None
),
RulesetId=None,
BeatmapId=None,
User=user,
)
class MatchRoomState(BaseModel):
class Config:
extra = "allow"
class MultiPlayerTeam(BaseModel):
id: int = 0
name: str = ""
class TeamVersusRoomState(BaseModel):
teams: list[MultiPlayerTeam] = []
class Config:
pass
@classmethod
def create_default(cls):
return cls(
teams=[
MultiPlayerTeam(id=0, name="Team Red"),
MultiPlayerTeam(id=1, name="Team Blue"),
]
)
MatchRoomStateType = TeamVersusRoomState | MatchRoomState
class MultiPlayerListItem(BaseModel):
id: int
OwnerID: int
BeatmapID: int
BeatmapChecksum: str = ""
RulesetID: int
RequierdMods: list[APIMod]
AllowedMods: list[APIMod]
Expired: bool
PlaylistOrder: int | None
PlayedAt: datetime | None
StarRating: float
Freestyle: bool
@classmethod
async def from_apiItem(cls, item: PlaylistItem, db: AsyncSession, fetcher: Fetcher):
s = cls.model_validate(item.model_dump())
s.id = item.id
s.OwnerID = item.owner_id
if item.beatmap is None: # 从客户端接受的一定没有这字段
cur_beatmap = await Beatmap.get_or_fetch(
db, fetcher=fetcher, bid=item.beatmap_id
)
s.BeatmapID = cur_beatmap.id if cur_beatmap.id is not None else 0
s.BeatmapChecksum = cur_beatmap.checksum
s.StarRating = cur_beatmap.difficulty_rating
s.RulesetID = item.ruleset_id
s.RequierdMods = item.required_mods
s.AllowedMods = item.allowed_mods
s.Expired = item.expired
s.PlaylistOrder = item.playlist_order if item.playlist_order is not None else 0
s.PlayedAt = item.played_at
s.Freestyle = item.freestyle
return s
class MultiplayerCountdown(BaseModel):
id: int = 0
time_remaining: timedelta = timedelta(seconds=0)
is_exclusive: bool = True
class Config:
extra = "allow"
class MatchStartCountdown(MultiplayerCountdown):
pass
class ForceGameplayStartCountdown(MultiplayerCountdown):
pass
class ServerShuttingCountdown(MultiplayerCountdown):
pass
MultiplayerCountdownType = (
MatchStartCountdown
| ForceGameplayStartCountdown
| ServerShuttingCountdown
| MultiplayerCountdown
)
class PlaylistStatus(BaseModel):
count_active: int
count_total: int
ruleset_ids: list[int]
class MultiplayerRoom(BaseModel):
RoomId: int
State: MultiplayerRoomState
Settings: MultiplayerRoomSettings = MultiplayerRoomSettings(
PlaylistItemId=0,
MatchType=MatchType.HEAD_TO_HEAD,
QueueMode=QueueMode.HOST_ONLY,
AutoStartDuration=timedelta(0),
AutoSkip=False,
)
Users: list[MultiplayerRoomUser]
Host: MultiplayerRoomUser
MatchState: MatchRoomState | None
Playlist: list[MultiPlayerListItem]
ActivecCountDowns: list[MultiplayerCountdownType]
ChannelID: int
@classmethod
def CanAddPlayistItem(cls, user: MultiplayerRoomUser) -> bool:
return user == cls.Host or cls.Settings.QueueMode != QueueMode.HOST_ONLY
@classmethod
async def from_apiRoom(cls, room: Room, db: AsyncSession, fetcher: Fetcher):
s = cls.model_validate(room.model_dump())
s.RoomId = room.room_id if room.room_id is not None else 0
s.ChannelID = room.channel_id
s.Settings = MultiplayerRoomSettings.from_apiRoom(room)
s.Host = await MultiplayerRoomUser.from_id(room.host.id if room.host else 0, db)
s.Playlist = [
await MultiPlayerListItem.from_apiItem(item, db, fetcher)
for item in room.playlist
]
return s
class Room(BaseModel):
room_id: int
name: str
password: str | None
has_password: bool = Field(exclude=True)
host: User | None
category: RoomCategory
duration: int | None
starts_at: datetime | None
ends_at: datetime | None
max_particapants: int | None = Field(exclude=True)
particapant_count: int
recent_particapants: list[User]
type: MatchType
max_attempts: int | None
playlist: list[PlaylistItem]
playlist_item_status: list[RoomPlaylistItemStats]
difficulity_range: RoomDifficultyRange
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
current_user_score: PlaylistAggregateScore | None
current_playlist_item: PlaylistItem | None
channel_id: int
status: RoomStatus
availability: RoomAvailability = Field(exclude=True)
class Config:
exclude_none = True
@classmethod
async def from_mpRoom(
cls, room: MultiplayerRoom, db: AsyncSession, fetcher: Fetcher
):
s = cls.model_validate(room.model_dump())
s.room_id = room.RoomId
s.name = room.Settings.Name
s.password = room.Settings.Password
s.type = room.Settings.MatchType
s.queue_mode = room.Settings.QueueMode
s.auto_skip = room.Settings.AutoSkip
s.host = room.Host.User
s.playlist = [
await PlaylistItem.from_mpListItem(item, db, fetcher)
for item in room.Playlist
]
return s

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from enum import Enum, IntEnum
from enum import Enum
from typing import Literal, TypedDict
from .mods import API_MODS, APIMod, init_mods
@@ -93,52 +93,14 @@ class HitResult(str, Enum):
)
class HitResultInt(IntEnum):
PERFECT = 0
GREAT = 1
GOOD = 2
OK = 3
MEH = 4
MISS = 5
LARGE_TICK_HIT = 6
SMALL_TICK_HIT = 7
SLIDER_TAIL_HIT = 8
LARGE_BONUS = 9
SMALL_BONUS = 10
LARGE_TICK_MISS = 11
SMALL_TICK_MISS = 12
IGNORE_HIT = 13
IGNORE_MISS = 14
NONE = 15
COMBO_BREAK = 16
LEGACY_COMBO_INCREASE = 99
def is_hit(self) -> bool:
return self not in (
HitResultInt.NONE,
HitResultInt.IGNORE_MISS,
HitResultInt.COMBO_BREAK,
HitResultInt.LARGE_TICK_MISS,
HitResultInt.SMALL_TICK_MISS,
HitResultInt.MISS,
)
class LeaderboardType(Enum):
GLOBAL = "global"
FRIENDS = "friends"
FRIENDS = "friend"
COUNTRY = "country"
TEAM = "team"
ScoreStatistics = dict[HitResult, int]
ScoreStatisticsInt = dict[HitResultInt, int]
class SoloScoreSubmissionInfo(BaseModel):
@@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel):
class LegacyReplaySoloScoreInfo(TypedDict):
online_id: int
mods: list[APIMod]
statistics: ScoreStatisticsInt
maximum_statistics: ScoreStatisticsInt
statistics: ScoreStatistics
maximum_statistics: ScoreStatistics
client_version: str
rank: Rank
user_id: int

View File

@@ -1,54 +1,23 @@
from __future__ import annotations
import datetime
from typing import Any, get_origin
from dataclasses import dataclass
from typing import ClassVar
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
model_serializer,
model_validator,
)
def serialize_to_list(value: BaseModel) -> list[Any]:
data = []
for field, info in value.__class__.model_fields.items():
v = getattr(value, field)
anno = get_origin(info.annotation)
if anno and issubclass(anno, BaseModel):
data.append(serialize_to_list(v))
elif anno and issubclass(anno, list):
data.append(
TypeAdapter(
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
)
elif isinstance(v, datetime.datetime):
data.append([v, 0])
else:
data.append(v)
return data
@dataclass
class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_abbr: bool = True
class MessagePackArrayModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="before")
@classmethod
def unpack(cls, v: Any) -> Any:
if isinstance(v, list):
fields = list(cls.model_fields.keys())
if len(v) != len(fields):
raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
return dict(zip(fields, v))
return v
@model_serializer
def serialize(self) -> list[Any]:
return serialize_to_list(self)
class SignalRUnionMessage(BaseModel):
union_type: ClassVar[int]
class Transport(BaseModel):

View File

@@ -2,17 +2,17 @@ from __future__ import annotations
import datetime
from enum import IntEnum
from typing import Any
from typing import Annotated, Any
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from .score import (
ScoreStatisticsInt,
ScoreStatistics,
)
from .signalr import MessagePackArrayModel, UserState
from .signalr import SignalRMeta, UserState
from msgpack_lazer_api import APIMod
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
class SpectatedUserState(IntEnum):
@@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum):
Quit = 5
class SpectatorState(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class SpectatorState(BaseModel):
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
@@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel):
)
class ScoreProcessorStatistics(MessagePackArrayModel):
base_score: int
maximum_base_score: int
class ScoreProcessorStatistics(BaseModel):
base_score: float
maximum_base_score: float
accuracy_judgement_count: int
combo_portion: float
bouns_portion: float
bonus_portion: float
class FrameHeader(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class FrameHeader(BaseModel):
total_score: int
acc: float
accuracy: float
combo: int
max_combo: int
statistics: ScoreStatisticsInt = Field(default_factory=dict)
statistics: ScoreStatistics = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel):
# SMOKE = 16
class LegacyReplayFrame(MessagePackArrayModel):
class LegacyReplayFrame(BaseModel):
time: float # from ReplayFrame,the parent of LegacyReplayFrame
x: float | None = None
y: float | None = None
mouse_x: float | None = None
mouse_y: float | None = None
button_state: int
header: Annotated[
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
]
class FrameDataBundle(MessagePackArrayModel):
class FrameDataBundle(BaseModel):
header: FrameHeader
frames: list[LegacyReplayFrame]
@@ -106,18 +106,16 @@ class APIUser(BaseModel):
class ScoreInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mods: list[APIMod]
user: APIUser
ruleset: int
maximum_statistics: ScoreStatisticsInt
maximum_statistics: ScoreStatistics
id: int | None = None
total_score: int | None = None
acc: float | None = None
accuracy: float | None = None
max_combo: int | None = None
combo: int | None = None
statistics: ScoreStatisticsInt = Field(default_factory=dict)
statistics: ScoreStatistics = Field(default_factory=dict)
class StoreScore(BaseModel):

View File

@@ -2,15 +2,11 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from .score import GameMode
from .model import UTCBaseModel
from pydantic import BaseModel
if TYPE_CHECKING:
from app.database import LazerUserAchievement, Team
class PlayStyle(str, Enum):
MOUSE = "mouse"
@@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel):
count: int
class UserAchievement(BaseModel):
achieved_at: datetime
achievement_id: int
# 添加数据库模型转换方法
def to_db_model(self, user_id: int) -> "LazerUserAchievement":
from app.database import (
LazerUserAchievement,
)
return LazerUserAchievement(
user_id=user_id,
achievement_id=self.achievement_id,
achieved_at=self.achieved_at,
)
class RankHighest(BaseModel):
class RankHighest(UTCBaseModel):
rank: int
updated_at: datetime
@@ -104,115 +83,6 @@ class RankHistory(BaseModel):
data: list[int]
class DailyChallengeStats(BaseModel):
daily_streak_best: int = 0
daily_streak_current: int = 0
last_update: datetime | None = None
last_weekly_streak: datetime | None = None
playcount: int = 0
top_10p_placements: int = 0
top_50p_placements: int = 0
user_id: int
weekly_streak_best: int = 0
weekly_streak_current: int = 0
class Page(BaseModel):
html: str = ""
raw: str = ""
class User(BaseModel):
# 基本信息
id: int
username: str
avatar_url: str
country_code: str
default_group: str = "default"
is_active: bool = True
is_bot: bool = False
is_deleted: bool = False
is_online: bool = True
is_supporter: bool = False
is_restricted: bool = False
last_visit: datetime | None = None
pm_friends_only: bool = False
profile_colour: str | None = None
# 个人资料
cover_url: str | None = None
discord: str | None = None
has_supported: bool = False
interests: str | None = None
join_date: datetime
location: str | None = None
max_blocks: int = 100
max_friends: int = 500
occupation: str | None = None
playmode: GameMode = GameMode.OSU
playstyle: list[PlayStyle] = []
post_count: int = 0
profile_hue: int | None = None
profile_order: list[str] = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
title: str | None = None
title_url: str | None = None
twitter: str | None = None
website: str | None = None
session_verified: bool = False
support_level: int = 0
# 关联对象
country: Country
cover: Cover
kudosu: Kudosu
statistics: Statistics
statistics_rulesets: dict[str, Statistics]
# 计数信息
beatmap_playcounts_count: int = 0
comments_count: int = 0
favourite_beatmapset_count: int = 0
follower_count: int = 0
graveyard_beatmapset_count: int = 0
guest_beatmapset_count: int = 0
loved_beatmapset_count: int = 0
mapping_follower_count: int = 0
nominated_beatmapset_count: int = 0
pending_beatmapset_count: int = 0
ranked_beatmapset_count: int = 0
ranked_and_approved_beatmapset_count: int = 0
unranked_beatmapset_count: int = 0
scores_best_count: int = 0
scores_first_count: int = 0
scores_pinned_count: int = 0
scores_recent_count: int = 0
# 历史数据
account_history: list[dict] = []
active_tournament_banner: dict | None = None
active_tournament_banners: list[dict] = []
badges: list[dict] = []
current_season_stats: dict | None = None
daily_challenge_user_stats: DailyChallengeStats | None = None
groups: list[dict] = []
monthly_playcounts: list[MonthlyPlaycount] = []
page: Page = Page()
previous_usernames: list[str] = []
rank_highest: RankHighest | None = None
rank_history: RankHistory | None = None
rankHistory: RankHistory | None = None # 兼容性别名
replays_watched_counts: list[dict] = []
team: "Team | None" = None
user_achievements: list[UserAchievement] = []
class APIUser(BaseModel):
id: int

View File

@@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
beatmapset,
me,
relationship,
room,
score,
user,
)
@@ -14,4 +15,9 @@ from .api_router import router as api_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
__all__ = [
"api_router",
"auth_router",
"fetcher_router",
"signalr_router",
]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import timedelta
from datetime import UTC, datetime, timedelta
import re
from app.auth import (
@@ -12,17 +12,21 @@ from app.auth import (
store_token,
)
from app.config import settings
from app.database import User as DBUser
from app.database import DailyChallengeStats, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -110,12 +114,12 @@ async def register_user(
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
result = await db.exec(select(User).where(User.username == user_username))
existing_user = result.first()
if existing_user:
username_errors.append("Username is already taken")
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
result = await db.exec(select(User).where(User.email == user_email))
existing_email = result.first()
if existing_email:
email_errors.append("Email is already taken")
@@ -135,119 +139,41 @@ async def register_user(
try:
# 创建新用户
from datetime import datetime
import time
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
)
)
next_id = result.one()[0]
if next_id <= 2:
await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
await db.commit()
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(), # 安全用户名(小写)
new_user = User(
username=user_username,
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country="CN", # 默认国家
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0, # 默认模式
play_style=0, # 默认游戏风格
country_code="CN", # 默认国家
join_date=datetime.now(UTC),
last_visit=datetime.now(UTC),
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
# 保存用户ID因为会话可能会关闭
user_id = new_user.id
if user_id <= 2:
await db.rollback()
try:
from sqlalchemy import text
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
await db.commit()
# 重新创建用户
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(),
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1,
country="CN",
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0,
play_style=0,
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
user_id = new_user.id
# 最终检查ID是否有效
if user_id <= 2:
await db.rollback()
errors = RegistrationRequestErrors(
message=(
"Failed to create account with valid ID. "
"Please contact support."
)
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
except Exception as fix_error:
await db.rollback()
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please try again."
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
# 创建默认的 lazer_profile
from app.database.user import LazerUserProfile
lazer_profile = LazerUserProfile(
user_id=user_id,
is_active=True,
is_bot=False,
is_deleted=False,
is_online=True,
is_supporter=False,
is_restricted=False,
session_verified=False,
has_supported=False,
pm_friends_only=False,
default_group="default",
join_date=datetime.utcnow(),
playmode="osu",
support_level=0,
max_blocks=50,
max_friends=250,
post_count=0,
)
db.add(lazer_profile)
assert new_user.id is not None, "New user ID should not be None"
for i in GameMode:
statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
db.add(daily_challenge_user_stats)
await db.commit()
# 返回成功响应
return JSONResponse(
status_code=201,
content={"message": "Account created successfully", "user_id": user_id},
)
except Exception as e:
except Exception:
await db.rollback()
# 打印详细错误信息用于调试
print(f"Registration error: {e}")
import traceback
traceback.print_exc()
logger.exception(f"Registration error for user {user_username}")
# 返回通用错误
errors = RegistrationRequestErrors(
@@ -323,6 +249,7 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
await store_token(
db,
user.id,

View File

@@ -5,12 +5,7 @@ import hashlib
import json
from app.calculator import calculate_beatmap_attribute
from app.database import (
Beatmap,
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.database import Beatmap, BeatmapResp, User
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -27,9 +22,8 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis import Redis
from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -39,7 +33,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -56,19 +50,19 @@ async def lookup_beatmap(
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -80,43 +74,28 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp)
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
b_ids: list[int] = Query(alias="ids[]", default_factory=list),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
else:
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
return BatchGetResp(
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
@router.post(
@@ -126,7 +105,7 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
beatmap: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None),
@@ -153,8 +132,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
if redis.exists(key):
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -164,7 +143,7 @@ async def get_beatmap_attributes(
)
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e))
redis.set(key, attr.model_dump_json())
await redis.set(key, attr.model_dump_json())
return attr
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
from app.database import (
Beatmapset,
BeatmapsetResp,
User as DBUser,
)
from typing import Literal
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.dependencies.database import get_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -12,27 +10,47 @@ from app.fetcher import Fetcher
from .api_router import router
from fastapi import Depends, HTTPException
from fastapi import Depends, Form, HTTPException, Query
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
async def lookup_beatmapset(
beatmap_id: int = Query(),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset_id = (
await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id))
).first()
if not beatmapset_id:
try:
resp = await fetcher.get_beatmap(beatmap_id)
await Beatmap.from_resp(db, resp)
await db.refresh(current_user)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
beatmapset = (
await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))
).first()
if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found")
resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user)
return resp
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
@@ -40,5 +58,55 @@ async def get_beatmapset(
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = BeatmapsetResp.from_db(beatmapset)
resp = await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return resp
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
async def download_beatmapset(
beatmapset: int,
no_video: bool = Query(True, alias="noVideo"),
current_user: User = Depends(get_current_user),
):
if current_user.country_code == "CN":
return RedirectResponse(
f"https://txy1.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}"
)
@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"])
async def favourite_beatmapset(
beatmapset: int,
action: Literal["favourite", "unfavourite"] = Form(),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
existing_favourite = (
await db.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.user_id == current_user.id,
FavouriteBeatmapset.beatmapset_id == beatmapset,
)
)
).first()
if action == "favourite" and existing_favourite:
raise HTTPException(status_code=400, detail="Already favourited")
elif action == "unfavourite" and not existing_favourite:
raise HTTPException(status_code=400, detail="Not favourited")
if action == "favourite":
favourite = FavouriteBeatmapset(
user_id=current_user.id, beatmapset_id=beatmapset
)
db.add(favourite)
else:
await db.delete(existing_favourite)
await db.commit()

View File

@@ -1,28 +1,27 @@
from __future__ import annotations
from typing import Literal
from app.database import (
User as DBUser,
)
from app.database import User, UserResp
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.models.user import (
User as ApiUser,
)
from app.utils import convert_db_user_to_api_user
from app.dependencies.database import get_db
from app.models.score import GameMode
from .api_router import router
from fastapi import Depends
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/me/{ruleset}", response_model=ApiUser)
@router.get("/me/", response_model=ApiUser)
@router.get("/me/{ruleset}", response_model=UserResp)
@router.get("/me/", response_model=UserResp)
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
ruleset: GameMode | None = None,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
"""获取当前用户信息默认使用osu模式"""
# 默认使用osu模式
api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user
return await UserResp.from_db(
current_user,
session,
ALL_INCLUDED,
ruleset,
)

View File

@@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
from sqlalchemy.orm import joinedload
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -26,17 +26,19 @@ async def get_relationship(
else RelationshipType.BLOCK
)
relationships = await db.exec(
select(Relationship)
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
select(Relationship).where(
Relationship.user_id == current_user.id,
Relationship.type == relationship_type,
)
)
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
class AddFriendResp(BaseModel):
user_relation: RelationshipResp
@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
@router.post("/blocks", tags=["relationship"])
async def add_relationship(
request: Request,
@@ -87,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW:
relationship = (
await db.exec(
select(Relationship)
.where(
select(Relationship).where(
Relationship.user_id == current_user_id,
Relationship.target_id == target,
)
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
)
).first()
assert relationship, "Relationship should exist after commit"

View File

@@ -1,106 +1,296 @@
from __future__ import annotations
from app.database.room import RoomIndex
from datetime import UTC, datetime
from typing import Literal
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp
from app.database.lazer_user import User, UserResp
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
from app.database.playlists import Playlist, PlaylistResp
from app.database.room import Room, RoomBase, RoomResp
from app.database.score import Score
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room
from app.models.multiplayer_hub import (
MultiplayerRoom,
MultiplayerRoomUser,
ServerMultiplayerRoom,
)
from app.models.room import RoomStatus
from app.signalr.hub import MultiplayerHubs
from .api_router import router
from api_router import router
from fastapi import Depends, HTTPException, Query
from sqlmodel import select
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/rooms", tags=["rooms"], response_model=list[Room])
@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp])
async def get_all_rooms(
mode: str | None = Query(None), # TODO: 对房间根据状态进行筛选
status: str | None = Query(None),
category: str | None = Query(
None
), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗)
mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open"
), # TODO: 对房间根据状态进行筛选
category: str = Query(default="realtime"), # TODO
status: RoomStatus | None = Query(None),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
redis: Redis = Depends(get_redis),
current_user: User = Depends(get_current_user),
):
all_roomID = (await db.exec(select(RoomIndex))).all()
redis = get_redis()
if redis is not None:
resp: list[Room] = []
for id in all_roomID:
dumped_room = redis.get(str(id))
validated_room = MultiplayerRoom.model_validate_json(str(dumped_room))
flag: bool = False
if status is not None:
if (
validated_room.State == MultiplayerRoomState.OPEN
and status == "idle"
):
flag = True
elif validated_room != MultiplayerRoomState.CLOSED:
flag = True
if flag:
resp.append(
await Room.from_mpRoom(
MultiplayerRoom.model_validate_json(str(dumped_room)),
db,
fetcher,
)
)
return resp
else:
raise HTTPException(status_code=500, detail="Redis Error")
rooms = MultiplayerHubs.rooms.values()
resp_list: list[RoomResp] = []
for room in rooms:
# if category == "realtime" and room.category != "normal":
# continue
# elif category != room.category and category != "":
# continue
resp_list.append(await RoomResp.from_hub(room))
return resp_list
@router.get("/rooms/{room}", tags=["room"], response_model=Room)
class APICreatedRoom(RoomResp):
error: str = ""
class APIUploadedRoom(RoomBase):
def to_room(self) -> Room:
"""
将 APIUploadedRoom 转换为 Room 对象playlist 字段需单独处理。
"""
room_dict = self.model_dump()
room_dict.pop("playlist", None)
# host_id 已在字段中
return Room(**room_dict)
id: int | None
host_id: int | None = None
playlist: list[Playlist]
@router.post("/rooms", tags=["room"], response_model=APICreatedRoom)
async def create_room(
room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
# db_room = Room.from_resp(room)
await db.refresh(current_user)
user_id = current_user.id
db_room = room.to_room()
db_room.host_id = current_user.id if current_user.id else 1
db.add(db_room)
await db.commit()
await db.refresh(db_room)
playlist: list[Playlist] = []
# 处理 APIUploadedRoom 里的 playlist 字段
for item in room.playlist:
# 确保 room_id 正确赋值
item.id = await Playlist.get_next_id_for_room(db_room.id, db)
item.room_id = db_room.id
item.owner_id = user_id if user_id else 1
db.add(item)
await db.commit()
await db.refresh(item)
playlist.append(item)
await db.refresh(db_room)
db_room.playlist = playlist
server_room = ServerMultiplayerRoom(
room=MultiplayerRoom.from_db(db_room),
category=db_room.category,
start_at=datetime.now(UTC),
hub=MultiplayerHubs,
)
MultiplayerHubs.rooms[db_room.id] = server_room
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room))
created_room.error = ""
return created_room
@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp)
async def get_room(
room: int,
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
redis = get_redis()
if redis:
dumped_room = str(redis.get(str(room)))
if dumped_room is not None:
resp = await Room.from_mpRoom(
MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher
)
return resp
else:
raise HTTPException(status_code=404, detail="Room Not Found")
else:
raise HTTPException(status_code=500, detail="Redis error")
class APICreatedRoom(Room):
error: str | None
@router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom)
async def create_room(
room: Room,
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
redis = get_redis()
if redis:
room_index = RoomIndex()
db.add(room_index)
await db.commit()
await db.refresh(room_index)
server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher)
redis.set(str(room_index.id), server_room.model_dump_json())
room.room_id = room_index.id
return APICreatedRoom(**room.model_dump(), error=None)
else:
raise HTTPException(status_code=500, detail="redis error")
server_room = MultiplayerHubs.rooms[room]
return await RoomResp.from_hub(server_room)
@router.delete("/rooms/{room}", tags=["room"])
async def remove_room(room: int, db: AsyncSession = Depends(get_db)):
redis = get_redis()
if redis:
redis.delete(str(room))
room_index = await db.get(RoomIndex, room)
if room_index:
await db.delete(room_index)
async def delete_room(room: int, db: AsyncSession = Depends(get_db)):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
else:
await db.delete(db_room)
return None
@router.put("/rooms/{room}/users/{user}", tags=["room"])
async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)):
server_room = MultiplayerHubs.rooms[room]
server_room.room.users.append(MultiplayerRoomUser(user_id=user))
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is not None:
db_room.participant_count += 1
await db.commit()
resp = await RoomResp.from_hub(server_room)
await db.refresh(db_room)
for item in db_room.playlist:
resp.playlist.append(await PlaylistResp.from_db(item, ["beatmap"]))
return resp
else:
raise HTTPException(404, "room not found0")
class APILeaderboard(BaseModel):
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
user_score: ItemAttemptsResp | None = None
@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard)
async def get_room_leaderboard(
room: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
server_room = MultiplayerHubs.rooms[room]
if not server_room:
raise HTTPException(404, "Room not found")
aggs = await db.exec(
select(ItemAttemptsCount)
.where(ItemAttemptsCount.room_id == room)
.order_by(col(ItemAttemptsCount.total_score).desc())
)
aggs_resp = []
user_agg = None
for i, agg in enumerate(aggs):
resp = await ItemAttemptsResp.from_db(agg, db)
resp.position = i + 1
aggs_resp.append(resp)
if agg.user_id == current_user.id:
user_agg = resp
return APILeaderboard(
leaderboard=aggs_resp,
user_score=user_agg,
)
class RoomEvents(BaseModel):
beatmaps: list[BeatmapResp] = Field(default_factory=list)
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
current_playlist_item_id: int = 0
events: list[MultiplayerEventResp] = Field(default_factory=list)
first_event_id: int = 0
last_event_id: int = 0
playlist_items: list[PlaylistResp] = Field(default_factory=list)
room: RoomResp
user: list[UserResp] = Field(default_factory=list)
@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"])
async def get_room_events(
room_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
limit: int = Query(100, ge=1, le=1000),
after: int | None = Query(None, ge=0),
before: int | None = Query(None, ge=0),
):
events = (
await db.exec(
select(MultiplayerEvent)
.where(
MultiplayerEvent.room_id == room_id,
col(MultiplayerEvent.id) > after if after is not None else True,
col(MultiplayerEvent.id) < before if before is not None else True,
)
.order_by(col(MultiplayerEvent.id).desc())
.limit(limit)
)
).all()
user_ids = set()
playlist_items = {}
beatmap_ids = set()
event_resps = []
first_event_id = 0
last_event_id = 0
current_playlist_item_id = 0
for event in events:
event_resps.append(MultiplayerEventResp.from_db(event))
if event.user_id:
user_ids.add(event.user_id)
if event.playlist_item_id is not None and (
playitem := (
await db.exec(
select(Playlist).where(
Playlist.id == event.playlist_item_id,
Playlist.room_id == room_id,
)
)
).first()
):
current_playlist_item_id = playitem.id
playlist_items[event.playlist_item_id] = playitem
beatmap_ids.add(playitem.beatmap_id)
scores = await db.exec(
select(Score).where(
Score.playlist_item_id == event.playlist_item_id,
Score.room_id == room_id,
)
)
for score in scores:
user_ids.add(score.user_id)
beatmap_ids.add(score.beatmap_id)
assert event.id is not None
first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id)
if room := MultiplayerHubs.rooms.get(room_id):
current_playlist_item_id = room.queue.current_item.id
room_resp = await RoomResp.from_hub(room)
else:
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
room_resp = await RoomResp.from_db(room)
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [
await PlaylistResp.from_db(item) for item in playlist_items.values()
]
return RoomEvents(
beatmaps=beatmap_resps,
beatmapsets=beatmapset_resps,
current_playlist_item_id=current_playlist_item_id,
events=event_resps,
first_event_id=first_event_id,
last_event_id=last_event_id,
playlist_items=playlist_items_resps,
room=room_resp,
user=user_resps,
)

View File

@@ -1,18 +1,41 @@
from __future__ import annotations
from datetime import UTC, datetime
import time
from app.calculator import clamp
from app.database import (
User as DBUser,
Beatmap,
Playlist,
Room,
Score,
ScoreResp,
ScoreToken,
ScoreTokenResp,
User,
)
from app.database.playlist_attempts import ItemAttemptsCount
from app.database.playlist_best_score import (
PlaylistBestScore,
get_position,
process_playlist_best_score,
)
from app.database.score import (
MultiplayerScores,
ScoreAround,
get_leaderboard,
process_score,
process_user,
)
from app.database.beatmap import Beatmap
from app.database.score import Score, ScoreResp, process_score, process_user
from app.database.score_token import ScoreToken, ScoreTokenResp
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus
from app.models.score import (
INT_TO_MODE,
GameMode,
LeaderboardType,
Rank,
SoloScoreSubmissionInfo,
)
@@ -21,11 +44,75 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
from redis import Redis
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, select, true
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 10
async def submit_score(
info: SoloScoreSubmissionInfo,
beatmap: int,
token: int,
current_user: User,
db: AsyncSession,
redis: Redis,
fetcher: Fetcher,
item_id: int | None = None,
room_id: int | None = None,
):
if not info.passed:
info.rank = Rank.F
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
).first()
if not score_token or score_token.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Score token not found")
if score_token.score_id:
score = (
await db.exec(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap))
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
score = await process_score(
current_user,
beatmap,
ranked,
score_token,
info,
fetcher,
db,
redis,
)
await db.refresh(current_user)
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score)
class BeatmapScores(BaseModel):
scores: list[ScoreResp]
@@ -37,44 +124,26 @@ class BeatmapScores(BaseModel):
)
async def get_beatmap_scores(
beatmap: int,
mode: GameMode,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mode: GameMode | None = Query(None),
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
type: str = Query(None),
current_user: DBUser = Depends(get_current_user),
mods: list[str] = Query(default_factory=set, alias="mods[]"),
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
all_scores = (
await db.exec(
Score.select_clause_unique(
Score.beatmap_id == beatmap,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).all()
user_score = (
await db.exec(
Score.select_clause_unique(
Score.beatmap_id == beatmap,
Score.user_id == current_user.id,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).first()
all_scores, user_score = await get_leaderboard(
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
)
return BeatmapScores(
scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score, user_score.user)
if user_score
else None,
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
)
@@ -94,7 +163,7 @@ async def get_user_beatmap_score(
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -103,7 +172,7 @@ async def get_user_beatmap_score(
)
user_score = (
await db.exec(
Score.select_clause(True)
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
@@ -118,9 +187,10 @@ async def get_user_beatmap_score(
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
)
else:
resp = await ScoreResp.from_db(db, user_score)
return BeatmapUserScore(
position=user_score.position if user_score.position is not None else 0,
score=await ScoreResp.from_db(db, user_score, user_score.user),
position=resp.rank_global or 0,
score=resp,
)
@@ -134,7 +204,7 @@ async def get_user_all_beatmap_scores(
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -143,7 +213,7 @@ async def get_user_all_beatmap_scores(
)
all_user_scores = (
await db.exec(
Score.select_clause()
select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
@@ -153,9 +223,7 @@ async def get_user_all_beatmap_scores(
)
).all()
return [
await ScoreResp.from_db(db, score, current_user) for score in all_user_scores
]
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
@router.post(
@@ -166,9 +234,10 @@ async def create_solo_score(
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id
async with db:
score_token = ScoreToken(
user_id=current_user.id,
@@ -190,64 +259,275 @@ async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
if not info.passed:
info.rank = Rank.F
async with db:
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
@router.post(
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
)
async def create_playlist_score(
room_id: int,
playlist_id: int,
beatmap_id: int = Form(),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
version_hash: str = Form(""),
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
if room.ended_at and room.ended_at < datetime.now(UTC):
raise HTTPException(status_code=400, detail="Room has ended")
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
).first()
if not score_token or score_token.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Score token not found")
if score_token.score_id:
score = (
await db.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist not found")
# validate
if not item.freestyle:
if item.ruleset_id != ruleset_id:
raise HTTPException(
status_code=400, detail="Ruleset mismatch in playlist item"
)
if item.beatmap_id != beatmap_id:
raise HTTPException(
status_code=400, detail="Beatmap ID mismatch in playlist item"
)
agg = await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room_id,
ItemAttemptsCount.user_id == current_user.id,
)
)
agg = agg.first()
if agg and room.max_attempts and agg.attempts >= room.max_attempts:
raise HTTPException(
status_code=422,
detail="You have reached the maximum attempts for this room",
)
if item.expired:
raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at:
raise HTTPException(
status_code=400, detail="Playlist item has already been played"
)
# 这里应该不用验证mod了吧。。。
score_token = ScoreToken(
user_id=current_user.id,
beatmap_id=beatmap_id,
ruleset_id=INT_TO_MODE[ruleset_id],
playlist_item_id=playlist_id,
)
session.add(score_token)
await session.commit()
await session.refresh(score_token)
return ScoreTokenResp.from_db(score_token)
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
async def submit_playlist_score(
room_id: int,
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist item not found")
user_id = current_user.id
score_resp = await submit_score(
info,
item.beatmap_id,
token,
current_user,
session,
redis,
fetcher,
item.id,
room_id,
)
await process_playlist_best_score(
room_id,
playlist_id,
user_id,
score_resp.id,
score_resp.total_score,
session,
redis,
)
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
return score_resp
class IndexedScoreResp(MultiplayerScores):
total: int
user_score: ScoreResp | None = None
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp
)
async def index_playlist_scores(
room_id: int,
playlist_id: int,
limit: int = 50,
cursor: int = Query(2000000, alias="cursor[total_score]"),
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
limit = clamp(limit, 1, 50)
scores = (
await session.exec(
select(PlaylistBestScore)
.where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.total_score < cursor,
)
.order_by(col(PlaylistBestScore.total_score).desc())
.limit(limit + 1)
)
).all()
has_more = len(scores) > limit
if has_more:
scores = scores[:-1]
user_score = None
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
for score in score_resp:
score.position = await get_position(room_id, playlist_id, score.id, session)
if score.user_id == current_user.id:
user_score = score
resp = IndexedScoreResp(
scores=score_resp,
user_score=user_score,
total=len(scores),
params={
"limit": limit,
},
)
if has_more:
resp.cursor = {
"total_score": scores[-1].total_score,
}
return resp
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
response_model=ScoreResp,
)
async def show_playlist_score(
room_id: int,
playlist_id: int,
score_id: int,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
start_time = time.time()
score_record = None
completed = False
while time.time() - start_time < READ_SCORE_TIMEOUT:
if score_record is None:
score_record = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.score_id == score_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
if completed_players := await redis.get(
f"multiplayer:{room_id}:gameplay:players"
):
completed = completed_players == "0"
if score_record and completed:
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(room_id, playlist_id, score_id, session)
if completed:
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
score = await process_score(
current_user,
beatmap,
ranked,
score_token,
info,
fetcher,
db,
redis,
)
await db.refresh(current_user)
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
assert score is not None
return await ScoreResp.from_db(db, score, current_user)
).all()
higher_scores = []
lower_scores = []
for score in scores:
if score.total_score > resp.total_score:
higher_scores.append(await ScoreResp.from_db(session, score.score))
elif score.total_score < resp.total_score:
lower_scores.append(await ScoreResp.from_db(session, score.score))
resp.scores_around = ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
return resp
@router.get(
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
response_model=ScoreResp,
)
async def get_user_playlist_score(
room_id: int,
playlist_id: int,
user_id: int,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
score_record = None
start_time = time.time()
while time.time() - start_time < READ_SCORE_TIMEOUT:
score_record = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.user_id == user_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
)
)
).first()
if score_record:
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(
room_id, playlist_id, score_record.score_id, session
)
return resp

View File

@@ -1,12 +1,9 @@
from __future__ import annotations
from typing import Literal
from app.database import User as DBUser
from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.dependencies.database import get_db
from app.models.score import INT_TO_MODE
from app.models.user import User as ApiUser
from app.utils import convert_db_user_to_api_user
from app.models.score import GameMode
from .api_router import router
@@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
# ---------- Shared Utility ----------
async def get_user_by_lookup(
db: AsyncSession, lookup: str, key: str = "id"
) -> DBUser | None:
"""根据查找方式获取用户"""
if key == "id":
try:
user_id = int(lookup)
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
return result.first()
except ValueError:
return None
elif key == "username":
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
return result.first()
else:
return None
# ---------- Batch Users ----------
class BatchUserResponse(BaseModel):
users: list[ApiUser]
users: list[UserResp]
@router.get("/users", response_model=BatchUserResponse)
@@ -51,75 +28,44 @@ async def get_users(
):
if user_ids:
searched_users = (
await session.exec(
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
)
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
).all()
else:
searched_users = (
await session.exec(DBUser.all_select_clause().limit(50))
).all()
searched_users = (await session.exec(select(User).limit(50))).all()
return BatchUserResponse(
users=[
await convert_db_user_to_api_user(
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
for searched_user in searched_users
]
)
# # ---------- Individual User ----------
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
# async def get_user_with_mode(
# user_lookup: str,
# mode: Literal["osu", "taiko", "fruits", "mania"],
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取指定游戏模式的用户信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, mode)
# @router.get("/users/{user_lookup}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
# async def get_user_default(
# user_lookup: str,
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取用户信息默认使用osu模式但包含所有模式的统计信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, "osu")
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
@router.get("/users/{user}/", response_model=ApiUser)
@router.get("/users/{user}", response_model=ApiUser)
@router.get("/users/{user}/{ruleset}", response_model=UserResp)
@router.get("/users/{user}/", response_model=UserResp)
@router.get("/users/{user}", response_model=UserResp)
async def get_user_info(
user: str,
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db),
):
searched_user = (
await session.exec(
DBUser.all_select_clause().where(
DBUser.id == int(user)
select(User).where(
User.id == int(user)
if user.isdigit()
else DBUser.name == user.removeprefix("@")
else User.username == user.removeprefix("@")
)
)
).first()
if not searched_user:
raise HTTPException(404, detail="User not found")
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
return await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
ruleset=ruleset,
)

View File

@@ -6,9 +6,9 @@ import time
from typing import Any
from app.config import settings
from app.exception import InvokeException
from app.log import logger
from app.models.signalr import UserState
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
@@ -21,7 +21,6 @@ from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
@@ -49,7 +48,7 @@ class Client:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self.procotol = protocol
self.protocol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
@@ -62,14 +61,14 @@ class Client:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
await self.connection.send_bytes(self.protocol.encode(packet))
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return []
return self.procotol.decode(d)
return self.protocol.decode(d)
async def _ping(self):
while True:
@@ -263,10 +262,9 @@ class Hub[TState: UserState]:
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
else:
call_params.append(args.pop(0))
call_params.append(
client.protocol.validate_object(args.pop(0), param.annotation)
)
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:

View File

@@ -2,15 +2,15 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from datetime import UTC, datetime
from typing import override
from app.database.relationship import Relationship, RelationshipType
from app.dependencies.database import engine
from app.database import Relationship, RelationshipType, User
from app.dependencies.database import engine, get_redis
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
from .hub import Client, Hub
from pydantic import TypeAdapter
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -30,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
) -> set[Coroutine]:
if store is not None and not store.pushable:
return set()
data = store.to_dict() if store else None
data = store.for_push if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
@@ -54,6 +54,18 @@ class MetadataHub(Hub[MetadataClientState]):
async def _clean_state(self, state: MetadataClientState) -> None:
if state.pushable:
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
redis = get_redis()
if await redis.exists(f"metadata:online:{state.connection_id}"):
await redis.delete(f"metadata:online:{state.connection_id}")
async with AsyncSession(engine) as session:
async with session.begin():
user = (
await session.exec(
select(User).where(User.id == int(state.connection_id))
)
).one()
user.last_visit = datetime.now(UTC)
await session.commit()
@override
def create_state(self, client: Client) -> MetadataClientState:
@@ -89,10 +101,14 @@ class MetadataHub(Hub[MetadataClientState]):
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
friend_state.to_dict(),
friend_state.for_push
if friend_state.pushable
else None,
)
)
await asyncio.gather(*tasks)
redis = get_redis()
await redis.set(f"metadata:online:{user_id}", "")
async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status)
@@ -107,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store.for_push,
)
)
await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
async def UpdateActivity(
self, client: Client, activity: UserActivity | None
) -> None:
user_id = int(client.connection_id)
activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.get_or_create_state(client)
store.user_activity = activity
store.activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store.for_push,
)
)
await asyncio.gather(*tasks)
@@ -139,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store,
)
for user_id, store in self.state.items()
if store.pushable

File diff suppressed because it is too large Load Diff

View File

@@ -7,15 +7,13 @@ import struct
import time
from typing import override
from app.database import Beatmap
from app.database import Beatmap, User
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.database.user import User
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
from app.models.signalr import serialize_to_list
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
@@ -70,8 +68,8 @@ def save_replay(
md5: str,
username: str,
score: Score,
statistics: ScoreStatisticsInt,
maximum_statistics: ScoreStatisticsInt,
statistics: ScoreStatistics,
maximum_statistics: ScoreStatistics,
frames: list[LegacyReplayFrame],
) -> None:
data = bytearray()
@@ -108,8 +106,8 @@ def save_replay(
last_time = 0
for frame in frames:
frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state}"
f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
f"|{frame.mouse_y or 0.0}|{frame.button_state}"
)
last_time = frame.time
frame_strs.append("-12345|0|0|0")
@@ -166,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
for user_id, store in self.state.items()
if store.state is not None
]
@@ -197,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
).first()
if not user:
return
name = user.name
name = user.username
store.state = state
store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum
@@ -215,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserBeganPlaying",
user_id,
serialize_to_list(state),
state,
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
@@ -223,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
state = self.get_or_create_state(client)
if not state.score:
return
state.score.score_info.acc = frame_data.header.acc
state.score.score_info.accuracy = frame_data.header.accuracy
state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics
@@ -234,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserSentFrames",
user_id,
frame_data.model_dump(),
frame_data,
)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
@@ -297,19 +293,18 @@ class SpectatorHub(Hub[StoreClientState]):
score_record.id,
)
# save replay
if store.state.state == SpectatedUserState.Passed:
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=store.score.score_info.statistics,
maximum_statistics=store.score.score_info.maximum_statistics,
frames=store.score.replay_frames,
)
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=store.score.score_info.statistics,
maximum_statistics=store.score.score_info.maximum_statistics,
frames=store.score.replay_frames,
)
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
if state.state == SpectatedUserState.Playing:
@@ -318,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
serialize_to_list(state) if state else None,
state,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
@@ -329,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
target_store.state,
)
store = self.get_or_create_state(client)
store.watched_user.add(target_id)
@@ -339,7 +334,7 @@ class SpectatorHub(Hub[StoreClientState]):
async with AsyncSession(engine) as session:
async with session.begin():
username = (
await session.exec(select(User.name).where(User.id == user_id))
await session.exec(select(User.username).where(User.id == user_id))
).first()
if not username:
return

View File

@@ -1,14 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
import datetime
from enum import Enum, IntEnum
import inspect
import json
from types import NoneType, UnionType
from typing import (
Any,
Protocol as TypingProtocol,
Union,
get_args,
get_origin,
)
from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
import msgpack_lazer_api as m
from pydantic import BaseModel
SEP = b"\x1e"
@@ -73,8 +83,67 @@ class Protocol(TypingProtocol):
@staticmethod
def encode(packet: Packet) -> bytes: ...
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any: ...
class MsgpackProtocol:
@classmethod
def serialize_msgpack(cls, v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_to_list(v)
elif issubclass(typ, list):
return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, datetime.timedelta):
return int(v.total_seconds() * 10_000_000)
elif isinstance(v, dict):
return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
for k, value in v.items()
}
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
@classmethod
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
values = []
for field, info in value.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.member_ignore:
continue
values.append(cls.serialize_msgpack(v=getattr(value, field)))
if issubclass(value.__class__, SignalRUnionMessage):
return [value.__class__.union_type, values]
else:
return values
@staticmethod
def process_object(v: Any, typ: type[BaseModel]) -> Any:
if isinstance(v, list):
d = {}
i = 0
for field, info in typ.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.member_ignore:
continue
anno = info.annotation
if anno is None:
d[camel_to_snake(field)] = v[i]
else:
d[field] = MsgpackProtocol.validate_object(v[i], anno)
i += 1
return d
return v
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
@@ -140,6 +209,53 @@ class MsgpackProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
return datetime.timedelta(seconds=int(v / 10_000_000))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
union_type = v[0]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.union_type == union_type:
return cls.validate_object(v[1], arg)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
@@ -151,20 +267,24 @@ class MsgpackProtocol:
]
)
if packet.arguments is not None:
payload.append(packet.arguments)
payload.append(
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
)
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
result_kind = 2
if packet.error:
result_kind = 1
elif packet.result is None:
elif packet.result is not None:
result_kind = 3
payload.extend(
[
packet.invocation_id,
result_kind,
packet.error or packet.result or None,
packet.error
or MsgpackProtocol.serialize_msgpack(packet.result)
or None,
]
)
elif isinstance(packet, ClosePacket):
@@ -181,6 +301,86 @@ class MsgpackProtocol:
class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v, in_union)
elif isinstance(v, dict):
return {
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
for k, value in v.items()
}
elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime):
return v.isoformat()
elif isinstance(v, datetime.timedelta):
# d.hh:mm:ss
total_seconds = int(v.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{hours:02}:{minutes:02}:{seconds:02}"
elif isinstance(v, Enum) and dict_key:
return v.value
elif isinstance(v, Enum):
list_ = list(typ)
return list_.index(v)
return v
@classmethod
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
d = {}
is_union = issubclass(v.__class__, SignalRUnionMessage)
for field, info in v.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
name = (
snake_to_camel(
field,
metadata.use_abbr if metadata else True,
)
if not is_union
else snake_to_pascal(
field,
metadata.use_abbr if metadata else True,
)
)
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
if is_union and not in_union:
return {
"$dtype": v.__class__.__name__,
"$value": d,
}
return d
@staticmethod
def process_object(
v: Any, typ: type[BaseModel], from_union: bool = False
) -> dict[str, Any]:
d = {}
for field, info in typ.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
name = (
snake_to_camel(field, metadata.use_abbr if metadata else True)
if not from_union
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
)
value = v.get(name)
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
continue
d[field] = JSONProtocol.validate_object(value, anno)
return d
@staticmethod
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
@@ -225,6 +425,63 @@ class JSONProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v)
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
# d.hh:mm:ss
parts = v.split(":")
if len(parts) == 3:
return datetime.timedelta(
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
)
elif len(parts) == 2:
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
elif len(parts) == 1:
return datetime.timedelta(seconds=int(parts[0]))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
union_type = v["$dtype"]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.__name__ == union_type:
return cls.validate_object(v["$value"], arg, True)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
@@ -241,7 +498,9 @@ class JSONProtocol:
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = packet.arguments
payload["arguments"] = [
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
]
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
@@ -253,7 +512,7 @@ class JSONProtocol:
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = packet.result
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):

View File

@@ -1,465 +1,92 @@
from __future__ import annotations
from datetime import UTC, datetime
from app.database import (
LazerUserCounts,
LazerUserProfile,
LazerUserStatistics,
User as DBUser,
)
from app.models.user import (
Country,
Cover,
DailyChallengeStats,
GradeCounts,
Kudosu,
Level,
Page,
RankHighest,
RankHistory,
Statistics,
User,
UserAchievement,
)
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)"""
# 从db_user获取基本字段值
user_id = getattr(db_user, "id")
user_name = getattr(db_user, "name")
user_country = getattr(db_user, "country")
user_country_code = user_country # 在User模型中country字段就是country_code
# 获取 Lazer 用户资料
profile = db_user.lazer_profile
if not profile:
# 如果没有 lazer 资料,使用默认值
profile = LazerUserProfile(
user_id=user_id,
)
# 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
lzrcnt = db_user.lazer_counts
if not lzrcnt:
# 如果没有 lazer 计数,使用默认值
lzrcnt = LazerUserCounts(user_id=user_id)
# 获取指定模式的统计信息
user_stats = None
if db_user.lazer_statistics:
for stat in db_user.lazer_statistics:
if stat.mode == ruleset:
user_stats = stat
break
if not user_stats:
# 如果没有找到指定模式的统计,创建默认统计
user_stats = LazerUserStatistics(user_id=user_id)
# 获取国家信息
country_code = db_user.country_code if db_user.country_code is not None else "XX"
country = Country(code=str(country_code), name=get_country_name(str(country_code)))
# 获取 Kudosu 信息
kudosu = Kudosu(available=0, total=0)
# 获取计数信息
# counts = LazerUserCounts(user_id=user_id)
# 转换统计信息
statistics = Statistics(
count_100=user_stats.count_100,
count_300=user_stats.count_300,
count_50=user_stats.count_50,
count_miss=user_stats.count_miss,
level=Level(
current=user_stats.level_current, progress=user_stats.level_progress
),
global_rank=user_stats.global_rank,
global_rank_exp=user_stats.global_rank_exp,
pp=float(user_stats.pp) if user_stats.pp else 0.0,
pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0,
ranked_score=user_stats.ranked_score,
hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0,
play_count=user_stats.play_count,
play_time=user_stats.play_time,
total_score=user_stats.total_score,
total_hits=user_stats.total_hits,
maximum_combo=user_stats.maximum_combo,
replays_watched_by_others=user_stats.replays_watched_by_others,
is_ranked=user_stats.is_ranked,
grade_counts=GradeCounts(
ss=user_stats.grade_ss,
ssh=user_stats.grade_ssh,
s=user_stats.grade_s,
sh=user_stats.grade_sh,
a=user_stats.grade_a,
),
country_rank=user_stats.country_rank,
rank={"country": user_stats.country_rank} if user_stats.country_rank else None,
)
# 转换所有模式的统计信息
statistics_rulesets = {}
if db_user.lazer_statistics:
for stat in db_user.lazer_statistics:
statistics_rulesets[stat.mode] = Statistics(
count_100=stat.count_100,
count_300=stat.count_300,
count_50=stat.count_50,
count_miss=stat.count_miss,
level=Level(current=stat.level_current, progress=stat.level_progress),
global_rank=stat.global_rank,
global_rank_exp=stat.global_rank_exp,
pp=float(stat.pp) if stat.pp else 0.0,
pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0,
ranked_score=stat.ranked_score,
hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0,
play_count=stat.play_count,
play_time=stat.play_time,
total_score=stat.total_score,
total_hits=stat.total_hits,
maximum_combo=stat.maximum_combo,
replays_watched_by_others=stat.replays_watched_by_others,
is_ranked=stat.is_ranked,
grade_counts=GradeCounts(
ss=stat.grade_ss,
ssh=stat.grade_ssh,
s=stat.grade_s,
sh=stat.grade_sh,
a=stat.grade_a,
),
country_rank=stat.country_rank,
rank={"country": stat.country_rank} if stat.country_rank else None,
)
# 转换国家信息
country = Country(code=user_country_code, name=get_country_name(user_country_code))
# 转换封面信息
cover_url = (
profile.cover_url
if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
cover = Cover(
custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None
)
# 转换 Kudosu 信息
kudosu = Kudosu(available=0, total=0)
# 转换成就信息
user_achievements = []
if db_user.lazer_achievements:
for achievement in db_user.lazer_achievements:
user_achievements.append(
UserAchievement(
achieved_at=achievement.achieved_at,
achievement_id=achievement.achievement_id,
)
)
# 转换排名历史
rank_history = None
rank_history_data = None
for rh in db_user.rank_history:
if rh.mode == ruleset:
rank_history_data = rh.rank_data
break
if rank_history_data:
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
# 转换每日挑战统计
# daily_challenge_stats = None
# if db_user.daily_challenge_stats:
# dcs = db_user.daily_challenge_stats
# daily_challenge_stats = DailyChallengeStats(
# daily_streak_best=dcs.daily_streak_best,
# daily_streak_current=dcs.daily_streak_current,
# last_update=dcs.last_update,
# last_weekly_streak=dcs.last_weekly_streak,
# playcount=dcs.playcount,
# top_10p_placements=dcs.top_10p_placements,
# top_50p_placements=dcs.top_50p_placements,
# user_id=dcs.user_id,
# weekly_streak_best=dcs.weekly_streak_best,
# weekly_streak_current=dcs.weekly_streak_current,
# )
# 转换最高排名
rank_highest = None
if user_stats.rank_highest:
rank_highest = RankHighest(
rank=user_stats.rank_highest,
updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(),
)
# 转换团队信息
team = None
if db_user.team_membership:
team_member = db_user.team_membership # 假设用户只属于一个团队
team = team_member.team
# 创建用户对象
# 从db_user获取基本字段值
user_id = getattr(db_user, "id")
user_name = getattr(db_user, "name")
user_country = getattr(db_user, "country")
# 获取用户头像URL
avatar_url = None
# 首先检查 profile 中的 avatar_url
if profile and hasattr(profile, "avatar_url") and profile.avatar_url:
avatar_url = str(profile.avatar_url)
# 然后检查是否有关联的头像记录
if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None:
if db_user.avatar.r2_game_url:
# 优先使用游戏用的头像URL
avatar_url = str(db_user.avatar.r2_game_url)
elif db_user.avatar.r2_original_url:
# 其次使用原始头像URL
avatar_url = str(db_user.avatar.r2_original_url)
# 如果还是没有找到,通过查询获取
# if db_session and avatar_url is None:
# try:
# # 导入UserAvatar模型
# # 尝试查找用户的头像记录
# statement = select(UserAvatar).where(
# UserAvatar.user_id == user_id, UserAvatar.is_active == True
# )
# avatar_record = db_session.exec(statement).first()
# if avatar_record is not None:
# if avatar_record.r2_game_url is not None:
# # 优先使用游戏用的头像URL
# avatar_url = str(avatar_record.r2_game_url)
# elif avatar_record.r2_original_url is not None:
# # 其次使用原始头像URL
# avatar_url = str(avatar_record.r2_original_url)
# except Exception as e:
# print(f"获取用户头像时出错: {e}")
# print(f"最终头像URL: {avatar_url}")
# 如果仍然没有找到头像URL则使用默认URL
if avatar_url is None:
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
# 处理 profile_order 列表排序
profile_order = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
if profile and profile.profile_order:
profile_order = profile.profile_order.split(",")
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
active_tournament_banners = []
if db_user.active_banners:
for banner in db_user.active_banners:
active_tournament_banners.append(
{
"tournament_id": banner.tournament_id,
"image_url": banner.image_url,
"is_active": banner.is_active,
}
)
# 在convert_db_user_to_api_user函数中添加badges处理
badges = []
if db_user.lazer_badges:
for badge in db_user.lazer_badges:
badges.append(
{
"badge_id": badge.badge_id,
"awarded_at": badge.awarded_at,
"description": badge.description,
"image_url": badge.image_url,
"url": badge.url,
}
)
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
monthly_playcounts = []
if db_user.lazer_monthly_playcounts:
for playcount in db_user.lazer_monthly_playcounts:
monthly_playcounts.append(
{
"start_date": playcount.start_date.isoformat()
if playcount.start_date
else None,
"play_count": playcount.play_count,
}
)
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
previous_usernames = []
if db_user.lazer_previous_usernames:
for username in db_user.lazer_previous_usernames:
previous_usernames.append(
{
"username": username.username,
"changed_at": username.changed_at.isoformat()
if username.changed_at
else None,
}
)
# 在convert_db_user_to_api_user函数中添加replays_watched_counts处理
replays_watched_counts = []
if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched:
for replay in db_user.lazer_replays_watched:
replays_watched_counts.append(
{
"start_date": replay.start_date.isoformat()
if replay.start_date
else None,
"count": replay.count,
}
)
# 创建用户对象
user = User(
id=user_id,
username=user_name,
avatar_url=avatar_url,
country_code=str(country_code),
default_group=profile.default_group if profile else "default",
is_active=profile.is_active,
is_bot=profile.is_bot,
is_deleted=profile.is_deleted,
is_online=profile.is_online,
is_supporter=profile.is_supporter,
is_restricted=profile.is_restricted,
last_visit=db_user.last_visit,
pm_friends_only=profile.pm_friends_only,
profile_colour=profile.profile_colour,
cover_url=profile.cover_url
if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
discord=profile.discord if profile else None,
has_supported=profile.has_supported if profile else False,
interests=profile.interests if profile else None,
join_date=profile.join_date if profile.join_date else datetime.now(UTC),
location=profile.location if profile else None,
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
max_friends=profile.max_friends if profile and profile.max_friends else 500,
post_count=profile.post_count if profile and profile.post_count else 0,
profile_hue=profile.profile_hue if profile and profile.profile_hue else None,
profile_order=profile_order, # 使用排序后的 profile_order
title=profile.title if profile else None,
title_url=profile.title_url if profile else None,
twitter=profile.twitter if profile else None,
website=profile.website if profile else None,
session_verified=True,
support_level=profile.support_level if profile else 0,
country=country,
cover=cover,
kudosu=kudosu,
statistics=statistics,
statistics_rulesets=statistics_rulesets,
beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0,
comments_count=lzrcnt.comments_count if lzrcnt else 0,
favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0,
follower_count=lzrcnt.follower_count if lzrcnt else 0,
graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0,
guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0,
loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0,
mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0,
nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0,
pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0,
ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0,
ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count
if lzrcnt
else 0,
unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0,
scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0,
scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0,
scores_pinned_count=lzrcnt.scores_pinned_count,
scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0,
account_history=[], # TODO: 获取用户历史账户信息
# active_tournament_banner=len(active_tournament_banners),
active_tournament_banners=active_tournament_banners,
badges=badges,
current_season_stats=None,
daily_challenge_user_stats=DailyChallengeStats(
user_id=user_id,
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
if db_user.daily_challenge_stats
else 0,
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
if db_user.daily_challenge_stats
else 0,
last_update=db_user.daily_challenge_stats.last_update
if db_user.daily_challenge_stats
else None,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
if db_user.daily_challenge_stats
else None,
playcount=db_user.daily_challenge_stats.playcount
if db_user.daily_challenge_stats
else 0,
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
if db_user.daily_challenge_stats
else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
if db_user.daily_challenge_stats
else 0,
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
if db_user.daily_challenge_stats
else 0,
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
if db_user.daily_challenge_stats
else 0,
),
groups=[],
monthly_playcounts=monthly_playcounts,
page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
if profile.page_html or profile.page_raw
else Page(),
previous_usernames=previous_usernames,
rank_highest=rank_highest,
rank_history=rank_history,
rankHistory=rank_history,
replays_watched_counts=replays_watched_counts,
team=team,
user_achievements=user_achievements,
)
return user
def camel_to_snake(name: str) -> str:
"""Convert a camelCase string to snake_case."""
result = []
last_chr = ""
for char in name:
if char.isupper():
if not last_chr.isupper() and result:
result.append("_")
result.append(char.lower())
else:
result.append(char)
last_chr = char
return "".join(result)
def get_country_name(country_code: str) -> str:
"""根据国家代码获取国家名称"""
country_names = {
"CN": "China",
"JP": "Japan",
"US": "United States",
"GB": "United Kingdom",
"DE": "Germany",
"FR": "France",
"KR": "South Korea",
"CA": "Canada",
"AU": "Australia",
"BR": "Brazil",
# 可以添加更多国家
def snake_to_camel(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to camelCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
return country_names.get(country_code, "Unknown")
result = []
for part in parts:
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
if result:
result.append(part.capitalize())
else:
result.append(part.lower())
return "".join(result)
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to PascalCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
result = []
for part in parts:
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
result.append(part.capitalize())
return "".join(result)

114
main.py
View File

@@ -4,25 +4,27 @@ from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
from app.database import Team # noqa: F401
from app.dependencies.database import create_tables, engine
from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.fetcher import get_fetcher
from app.models.user import User
from app.router import api_router, auth_router, fetcher_router, signalr_router
from app.router import (
api_router,
auth_router,
fetcher_router,
signalr_router,
)
from fastapi import FastAPI
User.model_rebuild()
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
await create_tables()
get_fetcher() # 初始化 fetcher
await get_fetcher() # 初始化 fetcher
# on shutdown
yield
await engine.dispose()
await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
@@ -44,104 +46,6 @@ async def health_check():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
# @app.get("/api/v2/friends")
# async def get_friends():
# return JSONResponse(
# content=[
# {
# "id": 123456,
# "username": "BestFriend",
# "is_online": True,
# "is_supporter": False,
# "country": {"code": "US", "name": "United States"},
# }
# ]
# )
# @app.get("/api/v2/notifications")
# async def get_notifications():
# return JSONResponse(content={"notifications": [], "unread_count": 0})
# @app.post("/api/v2/chat/ack")
# async def chat_ack():
# return JSONResponse(content={"status": "ok"})
# @app.get("/api/v2/users/{user_id}/{mode}")
# async def get_user_mode(user_id: int, mode: str):
# return JSONResponse(
# content={
# "id": user_id,
# "username": "测试测试测",
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 114514,
# "global_rank": 666,
# "country_rank": 1,
# "hit_accuracy": 100,
# },
# "country": {"code": "JP", "name": "Japan"},
# }
# )
# @app.get("/api/v2/me")
# async def get_me():
# return JSONResponse(
# content={
# "id": 15651670,
# "username": "Googujiang",
# "is_online": True,
# "country": {"code": "JP", "name": "Japan"},
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 2826.26,
# "global_rank": 298026,
# "country_rank": 11220,
# "hit_accuracy": 95.7168,
# },
# }
# )
# @app.post("/signalr/metadata/negotiate")
# async def metadata_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "abc123",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/spectator/negotiate")
# async def spectator_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "spec456",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/multiplayer/negotiate")
# async def multiplayer_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "multi789",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
if __name__ == "__main__":
from app.log import logger # noqa: F401

View File

@@ -1,8 +1,8 @@
"""score: remove best_id in database
"""beatmapset: support favourite count
Revision ID: 78be13c71791
Revises: dc4d25c428c7
Create Date: 2025-07-29 07:57:33.764517
Revision ID: 1178d0758ebf
Revises:
Create Date: 2025-08-01 04:05:09.882800
"""
@@ -15,8 +15,8 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "78be13c71791"
down_revision: str | Sequence[str] | None = "dc4d25c428c7"
revision: str = "1178d0758ebf"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "best_id")
op.drop_column("beatmapsets", "favourite_count")
# ### end Alembic commands ###
@@ -32,7 +32,9 @@ def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"scores",
sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True),
"beatmapsets",
sa.Column(
"favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False
),
)
# ### end Alembic commands ###

View File

@@ -0,0 +1,54 @@
"""relationship: fix unique relationship
Revision ID: 58a11441d302
Revises: 1178d0758ebf
Create Date: 2025-08-01 04:23:02.498166
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "58a11441d302"
down_revision: str | Sequence[str] | None = "1178d0758ebf"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"relationship",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
)
op.drop_constraint("PRIMARY", "relationship", type_="primary")
op.create_primary_key("pk_relationship", "relationship", ["id"])
op.alter_column(
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True
)
op.alter_column(
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("pk_relationship", "relationship", type_="primary")
op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"])
op.alter_column(
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False
)
op.alter_column(
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False
)
op.drop_column("relationship", "id")
# ### end Alembic commands ###

View File

@@ -0,0 +1,89 @@
"""playlist: index playlist id
Revision ID: d0c1b2cefe91
Revises: 58a11441d302
Create Date: 2025-08-06 06:02:10.512616
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "d0c1b2cefe91"
down_revision: str | Sequence[str] | None = "58a11441d302"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(
op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False
)
op.create_table(
"playlist_best_scores",
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("score_id", sa.BigInteger(), nullable=False),
sa.Column("room_id", sa.Integer(), nullable=False),
sa.Column("playlist_id", sa.Integer(), nullable=False),
sa.Column("total_score", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint(
["playlist_id"],
["room_playlists.id"],
),
sa.ForeignKeyConstraint(
["room_id"],
["rooms.id"],
),
sa.ForeignKeyConstraint(
["score_id"],
["scores.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["lazer_users.id"],
),
sa.PrimaryKeyConstraint("score_id"),
)
op.create_index(
op.f("ix_playlist_best_scores_playlist_id"),
"playlist_best_scores",
["playlist_id"],
unique=False,
)
op.create_index(
op.f("ix_playlist_best_scores_room_id"),
"playlist_best_scores",
["room_id"],
unique=False,
)
op.create_index(
op.f("ix_playlist_best_scores_user_id"),
"playlist_best_scores",
["user_id"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores"
)
op.drop_index(
op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores"
)
op.drop_index(
op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores"
)
op.drop_table("playlist_best_scores")
op.drop_index(op.f("ix_room_playlists_id"), table_name="room_playlists")
# ### end Alembic commands ###

View File

@@ -1,36 +0,0 @@
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
Revision ID: dc4d25c428c7
Revises:
Create Date: 2025-07-29 01:43:40.221070
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "dc4d25c428c7"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "nsmall_tick_hit")
op.drop_column("scores", "nlarge_tick_hit")
# ### end Alembic commands ###

View File

@@ -1,11 +1,4 @@
from typing import Any
class APIMod:
def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ...
@property
def acronym(self) -> str: ...
@property
def settings(self) -> str: ...
def encode(obj: Any) -> bytes: ...
def decode(data: bytes) -> Any: ...

View File

@@ -1,8 +1,6 @@
use crate::APIMod;
use chrono::{TimeZone, Utc};
use pyo3::types::PyDict;
use pyo3::{prelude::*, IntoPyObjectExt};
use std::collections::HashMap;
use std::io::Read;
pub fn read_object(
@@ -13,6 +11,8 @@ pub fn read_object(
match rmp::decode::read_marker(cursor) {
Ok(marker) => match marker {
rmp::Marker::Null => Ok(py.None()),
rmp::Marker::True => Ok(true.into_py_any(py)?),
rmp::Marker::False => Ok(false.into_py_any(py)?),
rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()),
rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()),
rmp::Marker::U8 => {
@@ -86,8 +86,6 @@ pub fn read_object(
cursor.read_exact(&mut data).map_err(to_py_err)?;
Ok(data.into_pyobject(py)?.into_any().unbind())
}
rmp::Marker::True => Ok(true.into_py_any(py)?),
rmp::Marker::False => Ok(false.into_py_any(py)?),
rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32),
rmp::Marker::Str8 => {
let mut buf = [0u8; 1];
@@ -206,13 +204,12 @@ fn read_array(
let obj1 = read_object(py, cursor, false)?;
if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) {
let obj2 = read_object(py, cursor, true)?;
return Ok(APIMod {
acronym: obj1.extract::<String>(py)?,
settings: obj2.extract::<HashMap<String, PyObject>>(py)?,
}
.into_pyobject(py)?
.into_any()
.unbind());
let api_mod_dict = PyDict::new(py);
api_mod_dict.set_item("acronym", obj1)?;
api_mod_dict.set_item("settings", obj2)?;
return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind());
} else {
items.push(obj1);
i += 1;

View File

@@ -1,8 +1,7 @@
use crate::APIMod;
use chrono::{DateTime, Utc};
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods};
use chrono::{DateTime, Utc};
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods};
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
use pyo3::{Bound, PyAny, PyRef, Python};
use pyo3::{Bound, PyAny};
use std::io::Write;
fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) {
@@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec<u8>, obj: &Bound<'_, PyDict>) {
}
}
fn write_nil(buf: &mut Vec<u8>){
fn write_nil(buf: &mut Vec<u8>) {
rmp::encode::write_nil(buf).unwrap();
}
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
fn write_api_mod(buf: &mut Vec<u8>, api_mod: PyRef<APIMod>) {
rmp::encode::write_array_len(buf, 2).unwrap();
rmp::encode::write_str(buf, &api_mod.acronym).unwrap();
rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap();
for (k, v) in api_mod.settings.iter() {
rmp::encode::write_str(buf, k).unwrap();
Python::with_gil(|py| write_object(buf, &v.bind(py)));
fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool {
if let Ok(Some(acronym)) = dict.get_item("acronym") {
if let Ok(acronym_str) = acronym.extract::<String>() {
return acronym_str.len() == 2;
}
}
false
}
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
fn write_api_mod(buf: &mut Vec<u8>, api_mod: &Bound<'_, PyDict>) -> PyResult<()> {
let acronym = api_mod
.get_item("acronym")?
.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?;
let acronym_str = acronym.extract::<String>()?;
let settings = api_mod
.get_item("settings")?
.unwrap_or_else(|| PyDict::new(acronym.py()).into_any());
let settings_dict = settings.downcast::<PyDict>()?;
rmp::encode::write_array_len(buf, 2).unwrap();
rmp::encode::write_str(buf, &acronym_str).unwrap();
rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap();
for (k, v) in settings_dict.iter() {
let key_str = k.extract::<String>()?;
rmp::encode::write_str(buf, &key_str).unwrap();
write_object(buf, &v);
}
Ok(())
}
fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) {
@@ -110,22 +132,24 @@ pub fn write_object(buf: &mut Vec<u8>, obj: &Bound<'_, PyAny>) {
write_list(buf, list);
} else if let Ok(string) = obj.downcast::<PyString>() {
write_string(buf, string);
} else if let Ok(integer) = obj.downcast::<PyInt>() {
write_integer(buf, integer);
} else if let Ok(float) = obj.downcast::<PyFloat>() {
write_float(buf, float);
} else if let Ok(boolean) = obj.downcast::<PyBool>() {
write_bool(buf, boolean);
} else if let Ok(float) = obj.downcast::<PyFloat>() {
write_float(buf, float);
} else if let Ok(integer) = obj.downcast::<PyInt>() {
write_integer(buf, integer);
} else if let Ok(bytes) = obj.downcast::<PyBytes>() {
write_bin(buf, bytes);
} else if let Ok(dict) = obj.downcast::<PyDict>() {
write_hashmap(buf, dict);
if is_api_mod(dict) {
write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict));
} else {
write_hashmap(buf, dict);
}
} else if let Ok(_none) = obj.downcast::<PyNone>() {
write_nil(buf);
} else if let Ok(datetime) = obj.downcast::<PyDateTime>() {
write_datetime(buf, datetime);
} else if let Ok(api_mod) = obj.extract::<PyRef<APIMod>>() {
write_api_mod(buf, api_mod);
} else {
panic!("Unsupported type");
}

View File

@@ -2,30 +2,6 @@ mod decode;
mod encode;
use pyo3::prelude::*;
use std::collections::HashMap;
#[pyclass]
struct APIMod {
#[pyo3(get, set)]
acronym: String,
#[pyo3(get, set)]
settings: HashMap<String, PyObject>,
}
#[pymethods]
impl APIMod {
#[new]
fn new(acronym: String, settings: HashMap<String, PyObject>) -> Self {
APIMod { acronym, settings }
}
fn __repr__(&self) -> String {
format!(
"APIMod(acronym='{}', settings={:?})",
self.acronym, self.settings
)
}
}
#[pyfunction]
#[pyo3(name = "encode")]
@@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult<PyObject> {
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_py, m)?)?;
m.add_function(wrap_pyfunction!(decode_py, m)?)?;
m.add_class::<APIMod>()?;
Ok(())
}

View File

@@ -2,4 +2,4 @@
- `mods.json`: 包含了游戏中的所有可用mod的详细信息。
- Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json
- Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380`
- Version: 2025/7/30 `ff49b66b27a2850aea4b6b3ba563cfe936cb6082`

View File

@@ -2438,7 +2438,8 @@
"Settings": [],
"IncompatibleMods": [
"CN",
"RX"
"RX",
"MF"
],
"RequiresConfiguration": false,
"UserPlayable": false,
@@ -2460,7 +2461,8 @@
"AC",
"AT",
"CN",
"RX"
"RX",
"MF"
],
"RequiresConfiguration": false,
"UserPlayable": false,
@@ -2477,7 +2479,8 @@
"Settings": [],
"IncompatibleMods": [
"AT",
"CN"
"CN",
"MF"
],
"RequiresConfiguration": false,
"UserPlayable": true,
@@ -2638,6 +2641,24 @@
"ValidForMultiplayerAsFreeMod": true,
"AlwaysValidForSubmission": false
},
{
"Acronym": "MF",
"Name": "Moving Fast",
"Description": "Dashing by default, slow down!",
"Type": "Fun",
"Settings": [],
"IncompatibleMods": [
"AT",
"CN",
"RX"
],
"RequiresConfiguration": false,
"UserPlayable": true,
"ValidForMultiplayer": true,
"ValidForFreestyleAsRequiredMod": false,
"ValidForMultiplayerAsFreeMod": true,
"AlwaysValidForSubmission": false
},
{
"Acronym": "SV2",
"Name": "Score V2",