Merge branch 'feat/multiplayer-api' of https://github.com/GooGuTeam/osu_lazer_api into feat/multiplayer-api
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -37,6 +37,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
test-cert/
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -8,7 +8,7 @@ import string
|
||||
from app.config import settings
|
||||
from app.database import (
|
||||
OAuthToken,
|
||||
User as DBUser,
|
||||
User,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
@@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str:
|
||||
|
||||
async def authenticate_user_legacy(
|
||||
db: AsyncSession, name: str, password: str
|
||||
) -> DBUser | None:
|
||||
) -> User | None:
|
||||
"""
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
@@ -82,7 +82,7 @@ async def authenticate_user_legacy(
|
||||
pw_md5 = hashlib.md5(password.encode()).hexdigest()
|
||||
|
||||
# 2. 根据用户名查找用户
|
||||
statement = select(DBUser).where(DBUser.name == name)
|
||||
statement = select(User).where(User.username == name)
|
||||
user = (await db.exec(statement)).first()
|
||||
if not user:
|
||||
return None
|
||||
@@ -113,7 +113,7 @@ async def authenticate_user_legacy(
|
||||
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, username: str, password: str
|
||||
) -> DBUser | None:
|
||||
) -> User | None:
|
||||
"""验证用户身份"""
|
||||
return await authenticate_user_legacy(db, username, password)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .achievement import UserAchievement, UserAchievementResp
|
||||
from .auth import OAuthToken
|
||||
from .beatmap import (
|
||||
Beatmap as Beatmap,
|
||||
@@ -8,65 +9,64 @@ from .beatmapset import (
|
||||
BeatmapsetResp as BeatmapsetResp,
|
||||
)
|
||||
from .best_score import BestScore
|
||||
from .legacy import LegacyOAuthToken, LegacyUserStatistics
|
||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .lazer_user import (
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from .playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .room import Room, RoomResp
|
||||
from .score import (
|
||||
MultiplayerScores,
|
||||
Score,
|
||||
ScoreAround,
|
||||
ScoreBase,
|
||||
ScoreResp,
|
||||
ScoreStatistics,
|
||||
)
|
||||
from .score_token import ScoreToken, ScoreTokenResp
|
||||
from .statistics import (
|
||||
UserStatistics,
|
||||
UserStatisticsResp,
|
||||
)
|
||||
from .team import Team, TeamMember
|
||||
from .user import (
|
||||
DailyChallengeStats,
|
||||
LazerUserAchievement,
|
||||
LazerUserBadge,
|
||||
LazerUserBanners,
|
||||
LazerUserCountry,
|
||||
LazerUserCounts,
|
||||
LazerUserKudosu,
|
||||
LazerUserMonthlyPlaycounts,
|
||||
LazerUserPreviousUsername,
|
||||
LazerUserProfile,
|
||||
LazerUserProfileSections,
|
||||
LazerUserReplaysWatched,
|
||||
LazerUserStatistics,
|
||||
RankHistory,
|
||||
User,
|
||||
UserAchievement,
|
||||
UserAvatar,
|
||||
from .user_account_history import (
|
||||
UserAccountHistory,
|
||||
UserAccountHistoryResp,
|
||||
UserAccountHistoryType,
|
||||
)
|
||||
|
||||
BeatmapsetResp.model_rebuild()
|
||||
BeatmapResp.model_rebuild()
|
||||
__all__ = [
|
||||
"Beatmap",
|
||||
"BeatmapResp",
|
||||
"Beatmapset",
|
||||
"BeatmapsetResp",
|
||||
"BestScore",
|
||||
"DailyChallengeStats",
|
||||
"LazerUserAchievement",
|
||||
"LazerUserBadge",
|
||||
"LazerUserBanners",
|
||||
"LazerUserCountry",
|
||||
"LazerUserCounts",
|
||||
"LazerUserKudosu",
|
||||
"LazerUserMonthlyPlaycounts",
|
||||
"LazerUserPreviousUsername",
|
||||
"LazerUserProfile",
|
||||
"LazerUserProfileSections",
|
||||
"LazerUserReplaysWatched",
|
||||
"LazerUserStatistics",
|
||||
"LegacyOAuthToken",
|
||||
"LegacyUserStatistics",
|
||||
"DailyChallengeStatsResp",
|
||||
"FavouriteBeatmapset",
|
||||
"ItemAttemptsCount",
|
||||
"ItemAttemptsResp",
|
||||
"MultiplayerEvent",
|
||||
"MultiplayerEventResp",
|
||||
"MultiplayerScores",
|
||||
"OAuthToken",
|
||||
"RankHistory",
|
||||
"PPBestScore",
|
||||
"Playlist",
|
||||
"PlaylistBestScore",
|
||||
"PlaylistResp",
|
||||
"Relationship",
|
||||
"RelationshipResp",
|
||||
"RelationshipType",
|
||||
"Room",
|
||||
"RoomResp",
|
||||
"Score",
|
||||
"ScoreAround",
|
||||
"ScoreBase",
|
||||
"ScoreResp",
|
||||
"ScoreStatistics",
|
||||
@@ -75,6 +75,17 @@ __all__ = [
|
||||
"Team",
|
||||
"TeamMember",
|
||||
"User",
|
||||
"UserAccountHistory",
|
||||
"UserAccountHistoryResp",
|
||||
"UserAccountHistoryType",
|
||||
"UserAchievement",
|
||||
"UserAvatar",
|
||||
"UserAchievement",
|
||||
"UserAchievementResp",
|
||||
"UserResp",
|
||||
"UserStatistics",
|
||||
"UserStatisticsResp",
|
||||
]
|
||||
|
||||
for i in __all__:
|
||||
if i.endswith("Resp"):
|
||||
globals()[i].model_rebuild() # type: ignore[call-arg]
|
||||
|
||||
40
app/database/achievement.py
Normal file
40
app/database/achievement.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class UserAchievementBase(SQLModel, UTCBaseModel):
|
||||
achievement_id: int = Field(primary_key=True)
|
||||
achieved_at: datetime = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
|
||||
|
||||
class UserAchievement(UserAchievementBase, table=True):
|
||||
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
|
||||
)
|
||||
user: "User" = Relationship(back_populates="achievement")
|
||||
|
||||
|
||||
class UserAchievementResp(UserAchievementBase):
|
||||
@classmethod
|
||||
def from_db(cls, db_model: UserAchievement) -> "UserAchievementResp":
|
||||
return cls.model_validate(db_model)
|
||||
@@ -1,19 +1,21 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class OAuthToken(SQLModel, table=True):
|
||||
class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
access_token: str = Field(max_length=500, unique=True)
|
||||
refresh_token: str = Field(max_length=500, unique=True)
|
||||
|
||||
@@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class BeatmapOwner(SQLModel):
|
||||
id: int
|
||||
@@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True):
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
beatmap_status: BeatmapRankStatus
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
|
||||
beatmapset: Beatmapset = Relationship(
|
||||
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
|
||||
@property
|
||||
def can_ranked(self) -> bool:
|
||||
@@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(Beatmap.id == resp.id)
|
||||
)
|
||||
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
|
||||
).first()
|
||||
assert beatmap is not None, "Beatmap should not be None after commit"
|
||||
return beatmap
|
||||
@@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True):
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.where(
|
||||
select(Beatmap).where(
|
||||
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
||||
)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
@@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase):
|
||||
url: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_db(
|
||||
async def from_db(
|
||||
cls,
|
||||
beatmap: Beatmap,
|
||||
query_mode: GameMode | None = None,
|
||||
from_set: bool = False,
|
||||
session: AsyncSession | None = None,
|
||||
user: "User | None" = None,
|
||||
) -> "BeatmapResp":
|
||||
beatmap_ = beatmap.model_dump()
|
||||
if query_mode is not None and beatmap.mode != query_mode:
|
||||
@@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase):
|
||||
beatmap_["ranked"] = beatmap.beatmap_status.value
|
||||
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=session, user=user
|
||||
)
|
||||
return cls.model_validate(beatmap_)
|
||||
|
||||
@@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .lazer_user import BASE_INCLUDES, User, UserResp
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import Field, Relationship, SQLModel, col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
|
||||
class BeatmapCovers(SQLModel):
|
||||
@@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel):
|
||||
artist_unicode: str = Field(index=True)
|
||||
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
|
||||
creator: str
|
||||
favourite_count: int
|
||||
nsfw: bool = Field(default=False)
|
||||
play_count: int
|
||||
preview_url: str
|
||||
@@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel):
|
||||
|
||||
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
|
||||
ratings: list[int] = Field(default=None, sa_column=Column(JSON))
|
||||
# TODO: recent_favourites: Optional[list[User]] = None
|
||||
# TODO: related_users: Optional[list[User]] = None
|
||||
# TODO: user: Optional[User] = Field(default=None)
|
||||
track_id: int | None = Field(default=None) # feature artist?
|
||||
# TODO: has_favourited
|
||||
|
||||
# BeatmapsetExtended
|
||||
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
|
||||
@@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel):
|
||||
tags: str = Field(default="", sa_column=Column(Text))
|
||||
|
||||
|
||||
class Beatmapset(BeatmapsetBase, table=True):
|
||||
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
@@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True):
|
||||
hype_required: int = Field(default=0)
|
||||
availability_info: str | None = Field(default=None)
|
||||
download_disabled: bool = Field(default=False)
|
||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||
|
||||
@classmethod
|
||||
async def from_resp(
|
||||
@@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
genre: BeatmapTranslationText | None = None
|
||||
language: BeatmapTranslationText | None = None
|
||||
nominations: BeatmapNominations | None = None
|
||||
has_favourited: bool = False
|
||||
favourite_count: int = 0
|
||||
recent_favourites: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
|
||||
async def from_db(
|
||||
cls,
|
||||
beatmapset: Beatmapset,
|
||||
include: list[str] = [],
|
||||
session: AsyncSession | None = None,
|
||||
user: User | None = None,
|
||||
) -> "BeatmapsetResp":
|
||||
from .beatmap import BeatmapResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
beatmaps = [
|
||||
BeatmapResp.from_db(beatmap, from_set=True)
|
||||
for beatmap in beatmapset.beatmaps
|
||||
]
|
||||
update = {
|
||||
"beatmaps": [
|
||||
await BeatmapResp.from_db(beatmap, from_set=True)
|
||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||
],
|
||||
"hype": BeatmapHype(
|
||||
current=beatmapset.hype_current, required=beatmapset.hype_required
|
||||
),
|
||||
"availability": BeatmapAvailability(
|
||||
more_information=beatmapset.availability_info,
|
||||
download_disabled=beatmapset.download_disabled,
|
||||
),
|
||||
"genre": BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_genre.name,
|
||||
id=beatmapset.beatmap_genre.value,
|
||||
),
|
||||
"language": BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_language.name,
|
||||
id=beatmapset.beatmap_language.value,
|
||||
),
|
||||
"nominations": BeatmapNominations(
|
||||
required=beatmapset.nominations_required,
|
||||
current=beatmapset.nominations_current,
|
||||
),
|
||||
"status": beatmapset.beatmap_status.name.lower(),
|
||||
"ranked": beatmapset.beatmap_status.value,
|
||||
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
|
||||
**beatmapset.model_dump(),
|
||||
}
|
||||
if session and user:
|
||||
existing_favourite = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset.id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
update["has_favourited"] = existing_favourite is not None
|
||||
|
||||
if session and "recent_favourites" in include:
|
||||
recent_favourites = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset)
|
||||
.where(
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
|
||||
)
|
||||
.order_by(col(FavouriteBeatmapset.date).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
update["recent_favourites"] = [
|
||||
await UserResp.from_db(
|
||||
await favourite.awaitable_attrs.user,
|
||||
session=session,
|
||||
include=BASE_INCLUDES,
|
||||
)
|
||||
for favourite in recent_favourites
|
||||
]
|
||||
|
||||
if session:
|
||||
update["favourite_count"] = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
)
|
||||
).one()
|
||||
return cls.model_validate(
|
||||
{
|
||||
"beatmaps": beatmaps,
|
||||
"hype": BeatmapHype(
|
||||
current=beatmapset.hype_current, required=beatmapset.hype_required
|
||||
),
|
||||
"availability": BeatmapAvailability(
|
||||
more_information=beatmapset.availability_info,
|
||||
download_disabled=beatmapset.download_disabled,
|
||||
),
|
||||
"genre": BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_genre.name,
|
||||
id=beatmapset.beatmap_genre.value,
|
||||
),
|
||||
"language": BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_language.name,
|
||||
id=beatmapset.beatmap_language.value,
|
||||
),
|
||||
"nominations": BeatmapNominations(
|
||||
required=beatmapset.nominations_required,
|
||||
current=beatmapset.nominations_current,
|
||||
),
|
||||
"status": beatmapset.beatmap_status.name.lower(),
|
||||
"ranked": beatmapset.beatmap_status.value,
|
||||
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
|
||||
**beatmapset.model_dump(),
|
||||
}
|
||||
update,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.score import GameMode
|
||||
from app.models.score import GameMode, Rank
|
||||
|
||||
from .user import User
|
||||
from .lazer_user import User
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
@@ -20,22 +20,27 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class BestScore(SQLModel, table=True):
|
||||
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pp: float = Field(
|
||||
sa_column=Column(Float, default=0),
|
||||
)
|
||||
acc: float = Field(
|
||||
sa_column=Column(Float, default=0),
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
mods: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
rank: Rank
|
||||
|
||||
user: User = Relationship()
|
||||
score: "Score" = Relationship()
|
||||
score: "Score" = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[BestScore.score_id]",
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
|
||||
58
app/database/daily_challenge.py
Normal file
58
app/database/daily_challenge.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
|
||||
daily_streak_best: int = Field(default=0)
|
||||
daily_streak_current: int = Field(default=0)
|
||||
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
last_weekly_streak: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime)
|
||||
)
|
||||
playcount: int = Field(default=0)
|
||||
top_10p_placements: int = Field(default=0)
|
||||
top_50p_placements: int = Field(default=0)
|
||||
weekly_streak_best: int = Field(default=0)
|
||||
weekly_streak_current: int = Field(default=0)
|
||||
|
||||
|
||||
class DailyChallengeStats(DailyChallengeStatsBase, table=True):
|
||||
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("lazer_users.id"),
|
||||
unique=True,
|
||||
index=True,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
user: "User" = Relationship(back_populates="daily_challenge_stats")
|
||||
|
||||
|
||||
class DailyChallengeStatsResp(DailyChallengeStatsBase):
|
||||
user_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(
|
||||
cls,
|
||||
obj: DailyChallengeStats,
|
||||
) -> "DailyChallengeStatsResp":
|
||||
return cls.model_validate(obj)
|
||||
53
app/database/favourite_beatmapset.py
Normal file
53
app/database/favourite_beatmapset.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import datetime
|
||||
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.lazer_user import User
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
|
||||
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
|
||||
exclude=True,
|
||||
)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
beatmapset_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
ForeignKey("beatmapsets.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
date: datetime.datetime = Field(
|
||||
default=datetime.datetime.now(datetime.UTC),
|
||||
sa_column=Column(
|
||||
DateTime,
|
||||
),
|
||||
)
|
||||
|
||||
user: User = Relationship(back_populates="favourite_beatmapsets")
|
||||
beatmapset: Beatmapset = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "selectin",
|
||||
},
|
||||
back_populates="favourites",
|
||||
)
|
||||
333
app/database/lazer_user.py
Normal file
333
app/database/lazer_user.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import Country, Page, RankHistory
|
||||
|
||||
from .achievement import UserAchievement, UserAchievementResp
|
||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||
from .monthly_playcounts import MonthlyPlaycounts, MonthlyPlaycountsResp
|
||||
from .statistics import UserStatistics, UserStatisticsResp
|
||||
from .team import Team, TeamMember
|
||||
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .relationship import RelationshipResp
|
||||
|
||||
|
||||
class Kudosu(TypedDict):
|
||||
available: int
|
||||
total: int
|
||||
|
||||
|
||||
class RankHighest(TypedDict):
|
||||
rank: int
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class UserProfileCover(TypedDict):
|
||||
url: str
|
||||
custom_url: NotRequired[str]
|
||||
id: NotRequired[str]
|
||||
|
||||
|
||||
Badge = TypedDict(
|
||||
"Badge",
|
||||
{
|
||||
"awarded_at": datetime,
|
||||
"description": str,
|
||||
"image@2x_url": str,
|
||||
"image_url": str,
|
||||
"url": str,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class UserBase(UTCBaseModel, SQLModel):
|
||||
avatar_url: str = ""
|
||||
country_code: str = Field(default="CN", max_length=2, index=True)
|
||||
# ? default_group: str|None
|
||||
is_active: bool = True
|
||||
is_bot: bool = False
|
||||
is_supporter: bool = False
|
||||
last_visit: datetime | None = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
pm_friends_only: bool = False
|
||||
profile_colour: str | None = None
|
||||
username: str = Field(max_length=32, unique=True, index=True)
|
||||
page: Page = Field(sa_column=Column(JSON), default=Page(html="", raw=""))
|
||||
previous_usernames: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
# TODO: replays_watched_counts
|
||||
support_level: int = 0
|
||||
badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
# optional
|
||||
is_restricted: bool = False
|
||||
# blocks
|
||||
cover: UserProfileCover = Field(
|
||||
default=UserProfileCover(
|
||||
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
),
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_playcounts_count: int = 0
|
||||
# kudosu
|
||||
|
||||
# UserExtended
|
||||
playmode: GameMode = GameMode.OSU
|
||||
discord: str | None = None
|
||||
has_supported: bool = False
|
||||
interests: str | None = None
|
||||
join_date: datetime = Field(default=datetime.now(UTC))
|
||||
location: str | None = None
|
||||
max_blocks: int = 50
|
||||
max_friends: int = 500
|
||||
occupation: str | None = None
|
||||
playstyle: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
# TODO: post_count
|
||||
profile_hue: int | None = None
|
||||
profile_order: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"me",
|
||||
"recent_activity",
|
||||
"top_ranks",
|
||||
"medals",
|
||||
"historical",
|
||||
"beatmaps",
|
||||
"kudosu",
|
||||
],
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
title: str | None = None
|
||||
title_url: str | None = None
|
||||
twitter: str | None = None
|
||||
website: str | None = None
|
||||
|
||||
# undocumented
|
||||
comments_count: int = 0
|
||||
post_count: int = 0
|
||||
is_admin: bool = False
|
||||
is_gmt: bool = False
|
||||
is_qat: bool = False
|
||||
is_bng: bool = False
|
||||
|
||||
|
||||
class User(AsyncAttrs, UserBase, table=True):
|
||||
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
account_history: list[UserAccountHistory] = Relationship()
|
||||
statistics: list[UserStatistics] = Relationship()
|
||||
achievement: list[UserAchievement] = Relationship(back_populates="user")
|
||||
team_membership: TeamMember | None = Relationship(back_populates="user")
|
||||
daily_challenge_stats: DailyChallengeStats | None = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
|
||||
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
|
||||
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
|
||||
priv: int = Field(default=1, exclude=True)
|
||||
pw_bcrypt: str = Field(max_length=60, exclude=True)
|
||||
silence_end_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
donor_end_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
|
||||
|
||||
class UserResp(UserBase):
|
||||
id: int | None = None
|
||||
is_online: bool = False
|
||||
groups: list = [] # TODO
|
||||
country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
|
||||
favourite_beatmapset_count: int = 0 # TODO
|
||||
graveyard_beatmapset_count: int = 0 # TODO
|
||||
guest_beatmapset_count: int = 0 # TODO
|
||||
loved_beatmapset_count: int = 0 # TODO
|
||||
mapping_follower_count: int = 0 # TODO
|
||||
nominated_beatmapset_count: int = 0 # TODO
|
||||
pending_beatmapset_count: int = 0 # TODO
|
||||
ranked_beatmapset_count: int = 0 # TODO
|
||||
follow_user_mapping: list[int] = Field(default_factory=list)
|
||||
follower_count: int = 0
|
||||
friends: list["RelationshipResp"] | None = None
|
||||
scores_best_count: int = 0
|
||||
scores_first_count: int = 0
|
||||
scores_recent_count: int = 0
|
||||
scores_pinned_count: int = 0
|
||||
account_history: list[UserAccountHistoryResp] = []
|
||||
active_tournament_banners: list[dict] = [] # TODO
|
||||
kudosu: Kudosu = Field(default_factory=lambda: Kudosu(available=0, total=0)) # TODO
|
||||
monthly_playcounts: list[MonthlyPlaycountsResp] = Field(default_factory=list)
|
||||
unread_pm_count: int = 0 # TODO
|
||||
rank_history: RankHistory | None = None # TODO
|
||||
rank_highest: RankHighest | None = None # TODO
|
||||
statistics: UserStatisticsResp | None = None
|
||||
statistics_rulesets: dict[str, UserStatisticsResp] | None = None
|
||||
user_achievements: list[UserAchievementResp] = Field(default_factory=list)
|
||||
cover_url: str = "" # deprecated
|
||||
team: Team | None = None
|
||||
session_verified: bool = True
|
||||
daily_challenge_user_stats: DailyChallengeStatsResp | None = None
|
||||
|
||||
# TODO: monthly_playcounts, unread_pm_count, rank_history, user_preferences
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
obj: User,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
ruleset: GameMode | None = None,
|
||||
) -> "UserResp":
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
from .best_score import BestScore
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
|
||||
u = cls.model_validate(obj.model_dump())
|
||||
u.id = obj.id
|
||||
u.follower_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(Relationship)
|
||||
.where(
|
||||
Relationship.target_id == obj.id,
|
||||
Relationship.type == RelationshipType.FOLLOW,
|
||||
)
|
||||
)
|
||||
).one()
|
||||
u.scores_best_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(BestScore)
|
||||
.where(
|
||||
BestScore.user_id == obj.id,
|
||||
)
|
||||
.limit(200)
|
||||
)
|
||||
).one()
|
||||
redis = get_redis()
|
||||
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
|
||||
u.cover_url = (
|
||||
obj.cover.get(
|
||||
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
if obj.cover
|
||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
|
||||
if "friends" in include:
|
||||
u.friends = [
|
||||
await RelationshipResp.from_db(session, r)
|
||||
for r in (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == obj.id,
|
||||
Relationship.type == RelationshipType.FOLLOW,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
if "team" in include:
|
||||
if await obj.awaitable_attrs.team_membership:
|
||||
assert obj.team_membership
|
||||
u.team = obj.team_membership.team
|
||||
|
||||
if "account_history" in include:
|
||||
u.account_history = [
|
||||
UserAccountHistoryResp.from_db(ah)
|
||||
for ah in await obj.awaitable_attrs.account_history
|
||||
]
|
||||
|
||||
if "daily_challenge_user_stats":
|
||||
if await obj.awaitable_attrs.daily_challenge_stats:
|
||||
assert obj.daily_challenge_stats
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
|
||||
obj.daily_challenge_stats
|
||||
)
|
||||
|
||||
if "statistics" in include:
|
||||
current_stattistics = None
|
||||
for i in await obj.awaitable_attrs.statistics:
|
||||
if i.mode == (ruleset or obj.playmode):
|
||||
current_stattistics = i
|
||||
break
|
||||
u.statistics = (
|
||||
UserStatisticsResp.from_db(current_stattistics)
|
||||
if current_stattistics
|
||||
else None
|
||||
)
|
||||
|
||||
if "statistics_rulesets" in include:
|
||||
u.statistics_rulesets = {
|
||||
i.mode.value: UserStatisticsResp.from_db(i)
|
||||
for i in await obj.awaitable_attrs.statistics
|
||||
}
|
||||
|
||||
if "monthly_playcounts" in include:
|
||||
u.monthly_playcounts = [
|
||||
MonthlyPlaycountsResp.from_db(pc)
|
||||
for pc in await obj.awaitable_attrs.monthly_playcounts
|
||||
]
|
||||
|
||||
if "achievements" in include:
|
||||
u.user_achievements = [
|
||||
UserAchievementResp.from_db(ua)
|
||||
for ua in await obj.awaitable_attrs.achievement
|
||||
]
|
||||
|
||||
return u
|
||||
|
||||
|
||||
ALL_INCLUDED = [
|
||||
"friends",
|
||||
"team",
|
||||
"account_history",
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
"statistics_rulesets",
|
||||
"achievements",
|
||||
"monthly_playcounts",
|
||||
]
|
||||
|
||||
|
||||
SEARCH_INCLUDED = [
|
||||
"team",
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
"statistics_rulesets",
|
||||
"achievements",
|
||||
"monthly_playcounts",
|
||||
]
|
||||
|
||||
BASE_INCLUDES = [
|
||||
"team",
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
]
|
||||
@@ -1,94 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
# ============================================
|
||||
# 旧的兼容性表模型(保留以便向后兼容)
|
||||
# ============================================
|
||||
|
||||
|
||||
class LegacyUserStatistics(SQLModel, table=True):
|
||||
__tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
mode: str = Field(max_length=10) # osu, taiko, fruits, mania
|
||||
|
||||
# 基本统计
|
||||
count_100: int = Field(default=0)
|
||||
count_300: int = Field(default=0)
|
||||
count_50: int = Field(default=0)
|
||||
count_miss: int = Field(default=0)
|
||||
|
||||
# 等级信息
|
||||
level_current: int = Field(default=1)
|
||||
level_progress: int = Field(default=0)
|
||||
|
||||
# 排名信息
|
||||
global_rank: int | None = Field(default=None)
|
||||
global_rank_exp: int | None = Field(default=None)
|
||||
country_rank: int | None = Field(default=None)
|
||||
|
||||
# PP 和分数
|
||||
pp: float = Field(default=0.0)
|
||||
pp_exp: float = Field(default=0.0)
|
||||
ranked_score: int = Field(default=0)
|
||||
hit_accuracy: float = Field(default=0.0)
|
||||
total_score: int = Field(default=0)
|
||||
total_hits: int = Field(default=0)
|
||||
maximum_combo: int = Field(default=0)
|
||||
|
||||
# 游戏统计
|
||||
play_count: int = Field(default=0)
|
||||
play_time: int = Field(default=0)
|
||||
replays_watched_by_others: int = Field(default=0)
|
||||
is_ranked: bool = Field(default=False)
|
||||
|
||||
# 成绩等级计数
|
||||
grade_ss: int = Field(default=0)
|
||||
grade_ssh: int = Field(default=0)
|
||||
grade_s: int = Field(default=0)
|
||||
grade_sh: int = Field(default=0)
|
||||
grade_a: int = Field(default=0)
|
||||
|
||||
# 最高排名记录
|
||||
rank_highest: int | None = Field(default=None)
|
||||
rank_highest_updated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
# 关联关系
|
||||
user: Mapped["User"] = Relationship(back_populates="statistics")
|
||||
|
||||
|
||||
class LegacyOAuthToken(SQLModel, table=True):
|
||||
__tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
access_token: str = Field(max_length=255, index=True)
|
||||
refresh_token: str = Field(max_length=255, index=True)
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime))
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON))
|
||||
replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
# 用户关系
|
||||
user: "User" = Relationship()
|
||||
43
app/database/monthly_playcounts.py
Normal file
43
app/database/monthly_playcounts.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from datetime import date
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class MonthlyPlaycounts(SQLModel, table=True):
|
||||
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
year: int = Field(index=True)
|
||||
month: int = Field(index=True)
|
||||
playcount: int = Field(default=0)
|
||||
|
||||
user: "User" = Relationship(back_populates="monthly_playcounts")
|
||||
|
||||
|
||||
class MonthlyPlaycountsResp(SQLModel):
|
||||
start_date: date
|
||||
count: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_model: MonthlyPlaycounts) -> "MonthlyPlaycountsResp":
|
||||
return cls(
|
||||
start_date=date(db_model.year, db_model.month, 1),
|
||||
count=db_model.playcount,
|
||||
)
|
||||
56
app/database/multiplayer_event.py
Normal file
56
app/database/multiplayer_event.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerEventBase(SQLModel, UTCBaseModel):
|
||||
playlist_item_id: int | None = None
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
event_type: str = Field(index=True)
|
||||
|
||||
|
||||
class MultiplayerEvent(MultiplayerEventBase, table=True):
|
||||
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
event_detail: dict[str, Any] | None = Field(
|
||||
sa_column=Column(JSON),
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerEventResp(MultiplayerEventBase):
|
||||
id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp":
|
||||
return cls.model_validate(event)
|
||||
114
app/database/playlist_attempts.py
Normal file
114
app/database/playlist_attempts.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ItemAttemptsCountBase(SQLModel):
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
attempts: int = Field(default=0)
|
||||
completed: int = Field(default=0)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
accuracy: float = 0.0
|
||||
pp: float = 0
|
||||
total_score: int = 0
|
||||
|
||||
|
||||
class ItemAttemptsCount(ItemAttemptsCountBase, table=True):
|
||||
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
user: User = Relationship()
|
||||
|
||||
async def get_position(self, session: AsyncSession) -> int:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(ItemAttemptsCountBase.room_id),
|
||||
order_by=col(ItemAttemptsCountBase.total_score).desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(ItemAttemptsCountBase, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
async def update(self, session: AsyncSession):
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == self.room_id,
|
||||
PlaylistBestScore.user_id == self.user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
self.attempts = sum(score.attempts for score in playlist_scores)
|
||||
self.total_score = sum(score.total_score for score in playlist_scores)
|
||||
self.pp = sum(score.score.pp for score in playlist_scores)
|
||||
self.completed = len(playlist_scores)
|
||||
self.accuracy = (
|
||||
sum(score.score.accuracy * score.attempts for score in playlist_scores)
|
||||
/ self.completed
|
||||
if self.completed > 0
|
||||
else 0.0
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
@classmethod
|
||||
async def get_or_create(
|
||||
cls,
|
||||
room_id: int,
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
) -> "ItemAttemptsCount":
|
||||
item_attempts = await session.exec(
|
||||
select(cls).where(
|
||||
cls.room_id == room_id,
|
||||
cls.user_id == user_id,
|
||||
)
|
||||
)
|
||||
item_attempts = item_attempts.first()
|
||||
if item_attempts is None:
|
||||
item_attempts = cls(room_id=room_id, user_id=user_id)
|
||||
session.add(item_attempts)
|
||||
await session.commit()
|
||||
await session.refresh(item_attempts)
|
||||
await item_attempts.update(session)
|
||||
return item_attempts
|
||||
|
||||
|
||||
class ItemAttemptsResp(ItemAttemptsCountBase):
|
||||
user: UserResp | None = None
|
||||
position: int | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
item_attempts: ItemAttemptsCount,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
) -> "ItemAttemptsResp":
|
||||
resp = cls.model_validate(item_attempts)
|
||||
resp.user = await UserResp.from_db(
|
||||
item_attempts.user,
|
||||
session=session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
)
|
||||
if "position" in include:
|
||||
resp.position = await item_attempts.get_position(session)
|
||||
return resp
|
||||
109
app/database/playlist_best_score.py
Normal file
109
app/database/playlist_best_score.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .lazer_user import User
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .score import Score
|
||||
|
||||
|
||||
class PlaylistBestScore(SQLModel, table=True):
|
||||
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
playlist_id: int = Field(foreign_key="room_playlists.id", index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
attempts: int = Field(default=0) # playlist
|
||||
|
||||
user: User = Relationship()
|
||||
score: "Score" = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[PlaylistBestScore.score_id]",
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def process_playlist_best_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
score_id: int,
|
||||
total_score: int,
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
):
|
||||
previous = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if previous is None:
|
||||
score = PlaylistBestScore(
|
||||
user_id=user_id,
|
||||
score_id=score_id,
|
||||
room_id=room_id,
|
||||
playlist_id=playlist_id,
|
||||
total_score=total_score,
|
||||
)
|
||||
session.add(score)
|
||||
else:
|
||||
previous.score_id = score_id
|
||||
previous.total_score = total_score
|
||||
previous.attempts += 1
|
||||
await session.commit()
|
||||
await redis.decr(f"multiplayer:{room_id}:gameplay:players")
|
||||
|
||||
|
||||
async def get_position(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
session: AsyncSession,
|
||||
) -> int:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(
|
||||
col(PlaylistBestScore.playlist_id),
|
||||
col(PlaylistBestScore.room_id),
|
||||
),
|
||||
order_by=col(PlaylistBestScore.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = (
|
||||
select(PlaylistBestScore, rownum)
|
||||
.where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
143
app/database/playlists.py
Normal file
143
app/database/playlists.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .room import Room
|
||||
|
||||
|
||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
id: int = Field(index=True)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
ruleset_id: int = Field(ge=0, le=3)
|
||||
expired: bool = Field(default=False)
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
)
|
||||
allowed_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
required_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
freestyle: bool = Field(default=False)
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
beatmap: Beatmap = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
|
||||
@classmethod
|
||||
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
|
||||
cls.room_id == room_id
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
owner_id=playlist.owner_id,
|
||||
ruleset_id=playlist.ruleset_id,
|
||||
beatmap_id=playlist.beatmap_id,
|
||||
required_mods=playlist.required_mods,
|
||||
allowed_mods=playlist.allowed_mods,
|
||||
expired=playlist.expired,
|
||||
playlist_order=playlist.playlist_order,
|
||||
played_at=playlist.played_at,
|
||||
freestyle=playlist.freestyle,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
db_playlist.owner_id = playlist.owner_id
|
||||
db_playlist.ruleset_id = playlist.ruleset_id
|
||||
db_playlist.beatmap_id = playlist.beatmap_id
|
||||
db_playlist.required_mods = playlist.required_mods
|
||||
db_playlist.allowed_mods = playlist.allowed_mods
|
||||
db_playlist.expired = playlist.expired
|
||||
db_playlist.playlist_order = playlist.playlist_order
|
||||
db_playlist.played_at = playlist.played_at
|
||||
db_playlist.freestyle = playlist.freestyle
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
await session.refresh(db_playlist)
|
||||
playlist.id = db_playlist.id
|
||||
|
||||
@classmethod
|
||||
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == item_id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
await session.delete(db_playlist)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, playlist: Playlist, include: list[str] = []
|
||||
) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap, from_set=True)
|
||||
resp = cls.model_validate(data)
|
||||
return resp
|
||||
41
app/database/pp_best_score.py
Normal file
41
app/database/pp_best_score.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .lazer_user import User
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .beatmap import Beatmap
|
||||
from .score import Score
|
||||
|
||||
|
||||
class PPBestScore(SQLModel, table=True):
|
||||
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pp: float = Field(
|
||||
sa_column=Column(Float, default=0),
|
||||
)
|
||||
acc: float = Field(
|
||||
sa_column=Column(Float, default=0),
|
||||
)
|
||||
|
||||
user: User = Relationship()
|
||||
score: "Score" = Relationship()
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
@@ -1,8 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from app.models.user import User as APIUser
|
||||
|
||||
from .user import User as DBUser
|
||||
from .lazer_user import User, UserResp
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import (
|
||||
@@ -24,12 +22,16 @@ class RelationshipType(str, Enum):
|
||||
|
||||
class Relationship(SQLModel, table=True):
|
||||
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
|
||||
exclude=True,
|
||||
)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
@@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True):
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||
target: DBUser = SQLRelationship(
|
||||
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
|
||||
target: User = SQLRelationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[Relationship.target_id]",
|
||||
"lazy": "selectin",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RelationshipResp(BaseModel):
|
||||
target_id: int
|
||||
target: APIUser
|
||||
target: UserResp
|
||||
mutual: bool = False
|
||||
type: RelationshipType
|
||||
|
||||
@@ -58,8 +62,6 @@ class RelationshipResp(BaseModel):
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, relationship: Relationship
|
||||
) -> "RelationshipResp":
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
target_relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
@@ -75,7 +77,16 @@ class RelationshipResp(BaseModel):
|
||||
)
|
||||
return cls(
|
||||
target_id=relationship.target_id,
|
||||
target=await convert_db_user_to_api_user(relationship.target),
|
||||
target=await UserResp.from_db(
|
||||
relationship.target,
|
||||
session,
|
||||
include=[
|
||||
"team",
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
"statistics_rulesets",
|
||||
],
|
||||
),
|
||||
mutual=mutual,
|
||||
type=relationship.type,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,125 @@
|
||||
from sqlmodel import Field, SQLModel
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
from app.models.room import (
|
||||
MatchType,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomDifficultyRange,
|
||||
RoomPlaylistItemStats,
|
||||
RoomStatus,
|
||||
)
|
||||
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
|
||||
class RoomIndex(SQLModel, table=True):
|
||||
__tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType]
|
||||
id: int = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue]
|
||||
class RoomBase(SQLModel):
|
||||
name: str = Field(index=True)
|
||||
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||
duration: int | None = Field(default=None) # minutes
|
||||
starts_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
ended_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
participant_count: int = Field(default=0)
|
||||
max_attempts: int | None = Field(default=None) # playlists
|
||||
type: MatchType
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
status: RoomStatus
|
||||
# TODO: channel_id
|
||||
# recent_participants: list[User]
|
||||
|
||||
|
||||
class Room(RoomBase, table=True):
|
||||
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
host_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "joined",
|
||||
"cascade": "all, delete-orphan",
|
||||
"overlaps": "room",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
password: str | None = None
|
||||
host: UserResp | None = None
|
||||
playlist: list[PlaylistResp] = []
|
||||
playlist_item_stats: RoomPlaylistItemStats | None = None
|
||||
difficulty_range: RoomDifficultyRange | None = None
|
||||
current_playlist_item: PlaylistResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, room: Room) -> "RoomResp":
|
||||
resp = cls.model_validate(room.model_dump())
|
||||
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
||||
difficulty_range = RoomDifficultyRange(
|
||||
min=0,
|
||||
max=0,
|
||||
)
|
||||
rulesets = set()
|
||||
for playlist in room.playlist:
|
||||
stats.count_total += 1
|
||||
if not playlist.expired:
|
||||
stats.count_active += 1
|
||||
rulesets.add(playlist.ruleset_id)
|
||||
difficulty_range.min = min(
|
||||
difficulty_range.min, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.max = max(
|
||||
difficulty_range.max, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
||||
stats.ruleset_ids = list(rulesets)
|
||||
resp.playlist_item_stats = stats
|
||||
resp.difficulty_range = difficulty_range
|
||||
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
|
||||
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
|
||||
room = server_room.room
|
||||
resp = cls(
|
||||
id=room.room_id,
|
||||
name=room.settings.name,
|
||||
type=room.settings.match_type,
|
||||
queue_mode=room.settings.queue_mode,
|
||||
auto_skip=room.settings.auto_skip,
|
||||
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
|
||||
status=server_room.status,
|
||||
category=server_room.category,
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime
|
||||
from datetime import UTC, date, datetime
|
||||
import json
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.calculator import (
|
||||
calculate_pp,
|
||||
@@ -12,9 +13,8 @@ from app.calculator import (
|
||||
calculate_weighted_pp,
|
||||
clamp,
|
||||
)
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.database.user import LazerUserStatistics, User
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.database.team import TeamMember
|
||||
from app.models.model import RespWithCursor, UTCBaseModel
|
||||
from app.models.mods import APIMod, mods_can_get_pp
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
@@ -26,15 +26,24 @@ from app.models.score import (
|
||||
ScoreStatistics,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
from app.models.user import User as APIUser
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .best_score import BestScore
|
||||
from .lazer_user import User, UserResp
|
||||
from .monthly_playcounts import MonthlyPlaycounts
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import (
|
||||
Relationship as DBRelationship,
|
||||
RelationshipType,
|
||||
)
|
||||
from .score_token import ScoreToken
|
||||
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy.orm import aliased, joinedload
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
@@ -43,9 +52,10 @@ from sqlmodel import (
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
false,
|
||||
func,
|
||||
select,
|
||||
text,
|
||||
true,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql._expression_select_cls import SelectOfScalar
|
||||
@@ -54,7 +64,7 @@ if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
|
||||
class ScoreBase(SQLModel):
|
||||
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
# 基本字段
|
||||
accuracy: float
|
||||
map_md5: str = Field(max_length=32, index=True)
|
||||
@@ -78,10 +88,11 @@ class ScoreBase(SQLModel):
|
||||
default=0, sa_column=Column(BigInteger), exclude=True
|
||||
)
|
||||
type: str
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
|
||||
# optional
|
||||
# TODO: current_user_attributes
|
||||
position: int | None = Field(default=None) # multiplayer
|
||||
# position: int | None = Field(default=None) # multiplayer
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
@@ -89,12 +100,11 @@ class Score(ScoreBase, table=True):
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
|
||||
)
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
@@ -112,28 +122,13 @@ class Score(ScoreBase, table=True):
|
||||
gamemode: GameMode = Field(index=True)
|
||||
|
||||
# optional
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
user: "User" = Relationship()
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@property
|
||||
def is_perfect_combo(self) -> bool:
|
||||
return self.max_combo == self.beatmap.max_combo
|
||||
|
||||
@staticmethod
|
||||
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
|
||||
clause = select(Score).options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
if with_user:
|
||||
return clause.options(
|
||||
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
return clause
|
||||
|
||||
@staticmethod
|
||||
def select_clause_unique(
|
||||
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
||||
@@ -147,18 +142,7 @@ class Score(ScoreBase, table=True):
|
||||
)
|
||||
subq = select(Score, rownum).where(*where_clauses).subquery()
|
||||
best = aliased(Score, subq, adapt_on_names=True)
|
||||
return (
|
||||
select(best)
|
||||
.where(subq.c.rn == 1)
|
||||
.options(
|
||||
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
return select(best).where(subq.c.rn == 1)
|
||||
|
||||
|
||||
class ScoreResp(ScoreBase):
|
||||
@@ -173,22 +157,23 @@ class ScoreResp(ScoreBase):
|
||||
ruleset_id: int | None = None
|
||||
beatmap: BeatmapResp | None = None
|
||||
beatmapset: BeatmapsetResp | None = None
|
||||
user: APIUser | None = None
|
||||
user: UserResp | None = None
|
||||
statistics: ScoreStatistics | None = None
|
||||
maximum_statistics: ScoreStatistics | None = None
|
||||
rank_global: int | None = None
|
||||
rank_country: int | None = None
|
||||
position: int | None = None
|
||||
scores_around: "ScoreAround | None" = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, score: Score, user: User | None = None
|
||||
) -> "ScoreResp":
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
s = cls.model_validate(score.model_dump())
|
||||
assert score.id
|
||||
s.beatmap = BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
|
||||
await score.awaitable_attrs.beatmap
|
||||
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(
|
||||
score.beatmap.beatmapset, session=session, user=score.user
|
||||
)
|
||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
||||
s.ruleset_id = MODE_TO_INT[score.gamemode]
|
||||
@@ -220,163 +205,184 @@ class ScoreResp(ScoreBase):
|
||||
s.maximum_statistics = {
|
||||
HitResult.GREAT: score.beatmap.max_combo,
|
||||
}
|
||||
if user:
|
||||
s.user = await convert_db_user_to_api_user(user)
|
||||
s.user = await UserResp.from_db(
|
||||
score.user,
|
||||
session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
ruleset=score.gamemode,
|
||||
)
|
||||
s.rank_global = (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.map_md5,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
mode=score.gamemode,
|
||||
user=user or score.user,
|
||||
user=score.user,
|
||||
)
|
||||
or None
|
||||
)
|
||||
s.rank_country = (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.map_md5,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
score.gamemode,
|
||||
user or score.user,
|
||||
score.user,
|
||||
type=LeaderboardType.COUNTRY,
|
||||
)
|
||||
or None
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
class MultiplayerScores(RespWithCursor):
|
||||
scores: list[ScoreResp] = Field(default_factory=list)
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ScoreAround(SQLModel):
|
||||
higher: MultiplayerScores | None = None
|
||||
lower: MultiplayerScores | None = None
|
||||
|
||||
|
||||
async def get_best_id(session: AsyncSession, score_id: int) -> None:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc())
|
||||
.over(
|
||||
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(BestScore, rownum).subquery()
|
||||
subq = select(PPBestScore, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one_or_none()
|
||||
|
||||
|
||||
async def _score_where(
|
||||
type: LeaderboardType,
|
||||
beatmap: int,
|
||||
mode: GameMode,
|
||||
mods: list[str] | None = None,
|
||||
user: User | None = None,
|
||||
) -> list[ColumnElement[bool]] | None:
|
||||
wheres = [
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
]
|
||||
|
||||
if type == LeaderboardType.FRIENDS:
|
||||
if user and user.is_supporter:
|
||||
subq = (
|
||||
select(DBRelationship.target_id)
|
||||
.where(
|
||||
DBRelationship.type == RelationshipType.FOLLOW,
|
||||
DBRelationship.user_id == user.id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
|
||||
else:
|
||||
return None
|
||||
elif type == LeaderboardType.COUNTRY:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
col(BestScore.user).has(col(User.country_code) == user.country_code)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
elif type == LeaderboardType.TEAM:
|
||||
if user:
|
||||
team_membership = await user.awaitable_attrs.team_membership
|
||||
if team_membership:
|
||||
team_id = team_membership.team_id
|
||||
wheres.append(
|
||||
col(BestScore.user).has(
|
||||
col(User.team_membership).has(TeamMember.team_id == team_id)
|
||||
)
|
||||
)
|
||||
if mods:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
text(
|
||||
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
|
||||
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
|
||||
) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return wheres
|
||||
|
||||
|
||||
async def get_leaderboard(
|
||||
session: AsyncSession,
|
||||
beatmap_md5: str,
|
||||
beatmap: int,
|
||||
mode: GameMode,
|
||||
type: LeaderboardType = LeaderboardType.GLOBAL,
|
||||
mods: list[APIMod] | None = None,
|
||||
mods: list[str] | None = None,
|
||||
user: User | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[Score]:
|
||||
scores = []
|
||||
if type == LeaderboardType.GLOBAL:
|
||||
query = (
|
||||
select(Score)
|
||||
.where(
|
||||
col(Beatmap.beatmap_status).in_(
|
||||
[
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
]
|
||||
),
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
Score.mods == mods if user and user.is_supporter else false(),
|
||||
)
|
||||
.limit(limit)
|
||||
.order_by(
|
||||
col(Score.total_score).desc(),
|
||||
)
|
||||
)
|
||||
result = await session.exec(query)
|
||||
scores = list[Score](result.all())
|
||||
elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
|
||||
# TODO
|
||||
...
|
||||
elif type == LeaderboardType.TEAM and user and user.team_membership:
|
||||
team_id = user.team_membership.team_id
|
||||
query = (
|
||||
select(Score)
|
||||
.join(Beatmap)
|
||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
col(Score.user.team_membership).is_not(None),
|
||||
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
|
||||
Score.mods == mods if user and user.is_supporter else false(),
|
||||
)
|
||||
.limit(limit)
|
||||
.order_by(
|
||||
col(Score.total_score).desc(),
|
||||
)
|
||||
)
|
||||
result = await session.exec(query)
|
||||
scores = list[Score](result.all())
|
||||
) -> tuple[list[Score], Score | None]:
|
||||
wheres = await _score_where(type, beatmap, mode, mods, user)
|
||||
if wheres is None:
|
||||
return [], None
|
||||
query = (
|
||||
select(BestScore)
|
||||
.where(*wheres)
|
||||
.limit(limit)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
)
|
||||
if mods:
|
||||
query = query.params(w=json.dumps(mods))
|
||||
scores = [s.score for s in await session.exec(query)]
|
||||
user_score = None
|
||||
if user:
|
||||
user_score = (
|
||||
await session.exec(
|
||||
select(Score).where(
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
Score.user_id == user.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
self_query = (
|
||||
select(BestScore)
|
||||
.where(BestScore.user_id == user.id)
|
||||
.where(
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
)
|
||||
).first()
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.limit(1)
|
||||
)
|
||||
if mods:
|
||||
self_query = self_query.where(
|
||||
text(
|
||||
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
|
||||
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
|
||||
)
|
||||
).params(w=json.dumps(mods))
|
||||
user_bs = (await session.exec(self_query)).first()
|
||||
if user_bs:
|
||||
user_score = user_bs.score
|
||||
if user_score and user_score not in scores:
|
||||
scores.append(user_score)
|
||||
return scores
|
||||
return scores, user_score
|
||||
|
||||
|
||||
async def get_score_position_by_user(
|
||||
session: AsyncSession,
|
||||
beatmap_md5: str,
|
||||
beatmap: int,
|
||||
user: User,
|
||||
mode: GameMode,
|
||||
type: LeaderboardType = LeaderboardType.GLOBAL,
|
||||
mods: list[APIMod] | None = None,
|
||||
mods: list[str] | None = None,
|
||||
) -> int:
|
||||
where_clause = [
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
col(Beatmap.beatmap_status).in_(
|
||||
[
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
]
|
||||
),
|
||||
]
|
||||
if mods and user.is_supporter:
|
||||
where_clause.append(Score.mods == mods)
|
||||
else:
|
||||
where_clause.append(false())
|
||||
if type == LeaderboardType.FRIENDS and user.is_supporter:
|
||||
# TODO
|
||||
...
|
||||
elif type == LeaderboardType.TEAM and user.team_membership:
|
||||
team_id = user.team_membership.team_id
|
||||
where_clause.append(
|
||||
col(Score.user.team_membership).is_not(None),
|
||||
)
|
||||
where_clause.append(
|
||||
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
|
||||
)
|
||||
wheres = await _score_where(type, beatmap, mode, mods, user=user)
|
||||
if wheres is None:
|
||||
return 0
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=Score.map_md5,
|
||||
order_by=col(Score.total_score).desc(),
|
||||
partition_by=col(BestScore.beatmap_id),
|
||||
order_by=col(BestScore.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery()
|
||||
stmt = select(subq.c.row_number).where(subq.c.user == user)
|
||||
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
|
||||
stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
@@ -384,57 +390,26 @@ async def get_score_position_by_user(
|
||||
|
||||
async def get_score_position_by_id(
|
||||
session: AsyncSession,
|
||||
beatmap_md5: str,
|
||||
beatmap: int,
|
||||
score_id: int,
|
||||
mode: GameMode,
|
||||
user: User | None = None,
|
||||
type: LeaderboardType = LeaderboardType.GLOBAL,
|
||||
mods: list[APIMod] | None = None,
|
||||
mods: list[str] | None = None,
|
||||
) -> int:
|
||||
where_clause = [
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.id == score_id,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
col(Beatmap.beatmap_status).in_(
|
||||
[
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
]
|
||||
),
|
||||
]
|
||||
if mods and user and user.is_supporter:
|
||||
where_clause.append(Score.mods == mods)
|
||||
elif mods:
|
||||
where_clause.append(false())
|
||||
wheres = await _score_where(type, beatmap, mode, mods, user=user)
|
||||
if wheres is None:
|
||||
return 0
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=[col(Score.user_id), col(Score.map_md5)],
|
||||
order_by=col(Score.total_score).desc(),
|
||||
partition_by=col(BestScore.beatmap_id),
|
||||
order_by=col(BestScore.total_score).desc(),
|
||||
)
|
||||
.label("rownum")
|
||||
.label("row_number")
|
||||
)
|
||||
subq = (
|
||||
select(Score.user_id, Score.id, Score.total_score, rownum)
|
||||
.join(Beatmap)
|
||||
.where(*where_clause)
|
||||
.subquery()
|
||||
)
|
||||
best_scores = aliased(subq)
|
||||
overall_rank = (
|
||||
func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
|
||||
)
|
||||
final_q = (
|
||||
select(best_scores.c.id, overall_rank)
|
||||
.select_from(best_scores)
|
||||
.where(best_scores.c.rownum == 1)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
|
||||
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
|
||||
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
@@ -445,16 +420,38 @@ async def get_user_best_score_in_beatmap(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
mode: GameMode | None = None,
|
||||
) -> Score | None:
|
||||
) -> BestScore | None:
|
||||
return (
|
||||
await session.exec(
|
||||
Score.select_clause(False)
|
||||
select(BestScore)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
BestScore.gamemode == mode if mode is not None else true(),
|
||||
BestScore.beatmap_id == beatmap,
|
||||
BestScore.user_id == user,
|
||||
)
|
||||
.order_by(col(Score.total_score).desc())
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
|
||||
# FIXME
|
||||
async def get_user_best_score_with_mod_in_beatmap(
|
||||
session: AsyncSession,
|
||||
beatmap: int,
|
||||
user: int,
|
||||
mod: list[str],
|
||||
mode: GameMode | None = None,
|
||||
) -> BestScore | None:
|
||||
return (
|
||||
await session.exec(
|
||||
select(BestScore)
|
||||
.where(
|
||||
BestScore.gamemode == mode if mode is not None else True,
|
||||
BestScore.beatmap_id == beatmap,
|
||||
BestScore.user_id == user,
|
||||
# BestScore.mods == mod,
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -464,13 +461,13 @@ async def get_user_best_pp_in_beatmap(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
mode: GameMode,
|
||||
) -> BestScore | None:
|
||||
) -> PPBestScore | None:
|
||||
return (
|
||||
await session.exec(
|
||||
select(BestScore).where(
|
||||
BestScore.beatmap_id == beatmap,
|
||||
BestScore.user_id == user,
|
||||
BestScore.gamemode == mode,
|
||||
select(PPBestScore).where(
|
||||
PPBestScore.beatmap_id == beatmap,
|
||||
PPBestScore.user_id == user,
|
||||
PPBestScore.gamemode == mode,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -480,12 +477,12 @@ async def get_user_best_pp(
|
||||
session: AsyncSession,
|
||||
user: int,
|
||||
limit: int = 200,
|
||||
) -> Sequence[BestScore]:
|
||||
) -> Sequence[PPBestScore]:
|
||||
return (
|
||||
await session.exec(
|
||||
select(BestScore)
|
||||
.where(BestScore.user_id == user)
|
||||
.order_by(col(BestScore.pp).desc())
|
||||
select(PPBestScore)
|
||||
.where(PPBestScore.user_id == user)
|
||||
.order_by(col(PPBestScore.pp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
).all()
|
||||
@@ -494,27 +491,45 @@ async def get_user_best_pp(
|
||||
async def process_user(
|
||||
session: AsyncSession, user: User, score: Score, ranked: bool = False
|
||||
):
|
||||
assert user.id
|
||||
assert score.id
|
||||
mod_for_save = list({mod["acronym"] for mod in score.mods})
|
||||
previous_score_best = await get_user_best_score_in_beatmap(
|
||||
session, score.beatmap_id, user.id, score.gamemode
|
||||
)
|
||||
statistics = None
|
||||
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
|
||||
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
|
||||
)
|
||||
add_to_db = False
|
||||
for i in user.lazer_statistics:
|
||||
mouthly_playcount = (
|
||||
await session.exec(
|
||||
select(MonthlyPlaycounts).where(
|
||||
MonthlyPlaycounts.user_id == user.id,
|
||||
MonthlyPlaycounts.year == date.today().year,
|
||||
MonthlyPlaycounts.month == date.today().month,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if mouthly_playcount is None:
|
||||
mouthly_playcount = MonthlyPlaycounts(
|
||||
user_id=user.id, year=date.today().year, month=date.today().month
|
||||
)
|
||||
add_to_db = True
|
||||
statistics = None
|
||||
for i in await user.awaitable_attrs.statistics:
|
||||
if i.mode == score.gamemode.value:
|
||||
statistics = i
|
||||
break
|
||||
if statistics is None:
|
||||
statistics = LazerUserStatistics(
|
||||
mode=score.gamemode.value,
|
||||
user_id=user.id,
|
||||
raise ValueError(
|
||||
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
|
||||
)
|
||||
add_to_db = True
|
||||
|
||||
# pc, pt, tth, tts
|
||||
statistics.total_score += score.total_score
|
||||
difference = (
|
||||
score.total_score - previous_score_best.total_score
|
||||
if previous_score_best and previous_score_best.id != score.id
|
||||
if previous_score_best
|
||||
else score.total_score
|
||||
)
|
||||
if difference > 0 and score.passed and ranked:
|
||||
@@ -541,11 +556,48 @@ async def process_user(
|
||||
statistics.grade_sh -= 1
|
||||
case Rank.A:
|
||||
statistics.grade_a -= 1
|
||||
else:
|
||||
previous_score_best = BestScore(
|
||||
user_id=user.id,
|
||||
beatmap_id=score.beatmap_id,
|
||||
gamemode=score.gamemode,
|
||||
score_id=score.id,
|
||||
total_score=score.total_score,
|
||||
rank=score.rank,
|
||||
mods=mod_for_save,
|
||||
)
|
||||
session.add(previous_score_best)
|
||||
|
||||
statistics.ranked_score += difference
|
||||
statistics.level_current = calculate_score_to_level(statistics.ranked_score)
|
||||
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
|
||||
if score.passed and ranked:
|
||||
if previous_score_best_mod is not None:
|
||||
previous_score_best_mod.mods = mod_for_save
|
||||
previous_score_best_mod.score_id = score.id
|
||||
previous_score_best_mod.rank = score.rank
|
||||
previous_score_best_mod.total_score = score.total_score
|
||||
elif (
|
||||
previous_score_best is not None and previous_score_best.score_id != score.id
|
||||
):
|
||||
session.add(
|
||||
BestScore(
|
||||
user_id=user.id,
|
||||
beatmap_id=score.beatmap_id,
|
||||
gamemode=score.gamemode,
|
||||
score_id=score.id,
|
||||
total_score=score.total_score,
|
||||
rank=score.rank,
|
||||
mods=mod_for_save,
|
||||
)
|
||||
)
|
||||
statistics.play_count += 1
|
||||
mouthly_playcount.playcount += 1
|
||||
statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
|
||||
statistics.count_100 += score.n100 + score.nkatu
|
||||
statistics.count_300 += score.n300 + score.ngeki
|
||||
statistics.count_50 += score.n50
|
||||
statistics.count_miss += score.nmiss
|
||||
statistics.total_hits += (
|
||||
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
|
||||
)
|
||||
@@ -563,11 +615,8 @@ async def process_user(
|
||||
acc_sum = clamp(acc_sum, 0.0, 100.0)
|
||||
statistics.pp = pp_sum
|
||||
statistics.hit_accuracy = acc_sum
|
||||
|
||||
statistics.updated_at = datetime.now(UTC)
|
||||
|
||||
if add_to_db:
|
||||
session.add(statistics)
|
||||
session.add(mouthly_playcount)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
@@ -581,7 +630,10 @@ async def process_score(
|
||||
fetcher: "Fetcher",
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
) -> Score:
|
||||
assert user.id
|
||||
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||
score = Score(
|
||||
accuracy=info.accuracy,
|
||||
@@ -611,6 +663,8 @@ async def process_score(
|
||||
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
|
||||
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
|
||||
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
|
||||
playlist_item_id=item_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
if can_get_pp:
|
||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
@@ -628,7 +682,7 @@ async def process_score(
|
||||
)
|
||||
if previous_pp_best is None or score.pp > previous_pp_best.pp:
|
||||
assert score.id
|
||||
best_score = BestScore(
|
||||
best_score = PPBestScore(
|
||||
user_id=user_id,
|
||||
score_id=score.id,
|
||||
beatmap_id=beatmap_id,
|
||||
@@ -637,7 +691,7 @@ async def process_score(
|
||||
acc=score.accuracy,
|
||||
)
|
||||
session.add(best_score)
|
||||
session.delete(previous_pp_best) if previous_pp_best else None
|
||||
await session.delete(previous_pp_best) if previous_pp_best else None
|
||||
await session.commit()
|
||||
await session.refresh(score)
|
||||
await session.refresh(score_token)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .beatmap import Beatmap
|
||||
from .user import User
|
||||
from .lazer_user import User
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
|
||||
class ScoreTokenBase(SQLModel):
|
||||
class ScoreTokenBase(SQLModel, UTCBaseModel):
|
||||
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
|
||||
ruleset_id: GameMode
|
||||
playlist_item_id: int | None = Field(default=None) # playlist
|
||||
@@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True):
|
||||
autoincrement=True,
|
||||
),
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id")
|
||||
user: "User" = Relationship()
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
user: User = Relationship()
|
||||
beatmap: Beatmap = Relationship()
|
||||
|
||||
|
||||
class ScoreTokenResp(ScoreTokenBase):
|
||||
|
||||
95
app/database/statistics.py
Normal file
95
app/database/statistics.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.score import GameMode
|
||||
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class UserStatisticsBase(SQLModel):
|
||||
mode: GameMode
|
||||
count_100: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
count_300: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
count_50: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
count_miss: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
|
||||
global_rank: int | None = Field(default=None)
|
||||
country_rank: int | None = Field(default=None)
|
||||
|
||||
pp: float = Field(default=0.0)
|
||||
ranked_score: int = Field(default=0)
|
||||
hit_accuracy: float = Field(default=0.00)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
total_hits: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
maximum_combo: int = Field(default=0)
|
||||
|
||||
play_count: int = Field(default=0)
|
||||
play_time: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
replays_watched_by_others: int = Field(default=0)
|
||||
is_ranked: bool = Field(default=True)
|
||||
|
||||
|
||||
class UserStatistics(UserStatisticsBase, table=True):
|
||||
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
grade_ss: int = Field(default=0)
|
||||
grade_ssh: int = Field(default=0)
|
||||
grade_s: int = Field(default=0)
|
||||
grade_sh: int = Field(default=0)
|
||||
grade_a: int = Field(default=0)
|
||||
|
||||
level_current: int = Field(default=1)
|
||||
level_progress: int = Field(default=0)
|
||||
|
||||
user: "User" = Relationship(back_populates="statistics") # type: ignore[valid-type]
|
||||
|
||||
|
||||
class UserStatisticsResp(UserStatisticsBase):
|
||||
grade_counts: dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"ss": 0,
|
||||
"ssh": 0,
|
||||
"s": 0,
|
||||
"sh": 0,
|
||||
"a": 0,
|
||||
}
|
||||
)
|
||||
level: dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"current": 1,
|
||||
"progress": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, obj: UserStatistics) -> "UserStatisticsResp":
|
||||
s = cls.model_validate(obj)
|
||||
s.grade_counts = {
|
||||
"ss": obj.grade_ss,
|
||||
"ssh": obj.grade_ssh,
|
||||
"s": obj.grade_s,
|
||||
"sh": obj.grade_sh,
|
||||
"a": obj.grade_a,
|
||||
}
|
||||
s.level = {
|
||||
"current": obj.level_current,
|
||||
"progress": obj.level_progress,
|
||||
}
|
||||
return s
|
||||
@@ -1,14 +1,16 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .lazer_user import User
|
||||
|
||||
|
||||
class Team(SQLModel, table=True):
|
||||
class Team(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
@@ -22,15 +24,19 @@ class Team(SQLModel, table=True):
|
||||
members: list["TeamMember"] = Relationship(back_populates="team")
|
||||
|
||||
|
||||
class TeamMember(SQLModel, table=True):
|
||||
class TeamMember(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
team_id: int = Field(foreign_key="teams.id")
|
||||
joined_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="team_membership")
|
||||
team: "Team" = Relationship(back_populates="members")
|
||||
user: "User" = Relationship(
|
||||
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
team: "Team" = Relationship(
|
||||
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
|
||||
@@ -1,527 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .legacy import LegacyUserStatistics
|
||||
from .team import TeamMember
|
||||
|
||||
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
|
||||
from sqlalchemy.dialects.mysql import VARCHAR
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
# 主键
|
||||
id: int = Field(
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
|
||||
)
|
||||
|
||||
# 基本信息(匹配 migrations_old 中的结构)
|
||||
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
||||
safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名
|
||||
email: str = Field(max_length=254, unique=True, index=True)
|
||||
priv: int = Field(default=1) # 权限
|
||||
pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码
|
||||
country: str = Field(default="CN", max_length=2) # 国家代码
|
||||
|
||||
# 状态和时间
|
||||
silence_end: int = Field(default=0)
|
||||
donor_end: int = Field(default=0)
|
||||
creation_time: int = Field(default=0) # Unix 时间戳
|
||||
latest_activity: int = Field(default=0) # Unix 时间戳
|
||||
|
||||
# 游戏相关
|
||||
preferred_mode: int = Field(default=0) # 偏好游戏模式
|
||||
play_style: int = Field(default=0) # 游戏风格
|
||||
|
||||
# 扩展信息
|
||||
clan_id: int = Field(default=0)
|
||||
clan_priv: int = Field(default=0)
|
||||
custom_badge_name: str | None = Field(default=None, max_length=16)
|
||||
custom_badge_icon: str | None = Field(default=None, max_length=64)
|
||||
userpage_content: str | None = Field(default=None, max_length=2048)
|
||||
api_key: str | None = Field(default=None, max_length=36, unique=True)
|
||||
|
||||
# 虚拟字段用于兼容性
|
||||
@property
|
||||
def username(self):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def country_code(self):
|
||||
return self.country
|
||||
|
||||
@property
|
||||
def join_date(self):
|
||||
creation_time = getattr(self, "creation_time", 0)
|
||||
return (
|
||||
datetime.fromtimestamp(creation_time)
|
||||
if creation_time > 0
|
||||
else datetime.utcnow()
|
||||
)
|
||||
|
||||
@property
|
||||
def last_visit(self):
|
||||
latest_activity = getattr(self, "latest_activity", 0)
|
||||
return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
|
||||
|
||||
@property
|
||||
def is_supporter(self):
|
||||
return self.lazer_profile.is_supporter if self.lazer_profile else False
|
||||
|
||||
# 关联关系
|
||||
lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
|
||||
lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
|
||||
lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user")
|
||||
lazer_achievements: list["LazerUserAchievement"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
lazer_profile_sections: list["LazerUserProfileSections"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
|
||||
team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
|
||||
daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
rank_history: list["RankHistory"] = Relationship(back_populates="user")
|
||||
avatar: Optional["UserAvatar"] = Relationship(back_populates="user")
|
||||
active_banners: list["LazerUserBanners"] = Relationship(back_populates="user")
|
||||
lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user")
|
||||
lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_select_option(cls):
|
||||
return (
|
||||
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_select_clause(cls):
|
||||
return select(cls).options(*cls.all_select_option())
|
||||
|
||||
|
||||
# ============================================
|
||||
# Lazer API 专用表模型
|
||||
# ============================================
|
||||
|
||||
|
||||
class LazerUserProfile(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 基本状态字段
|
||||
is_active: bool = Field(default=True)
|
||||
is_bot: bool = Field(default=False)
|
||||
is_deleted: bool = Field(default=False)
|
||||
is_online: bool = Field(default=True)
|
||||
is_supporter: bool = Field(default=False)
|
||||
is_restricted: bool = Field(default=False)
|
||||
session_verified: bool = Field(default=False)
|
||||
has_supported: bool = Field(default=False)
|
||||
pm_friends_only: bool = Field(default=False)
|
||||
|
||||
# 基本资料字段
|
||||
default_group: str = Field(default="default", max_length=50)
|
||||
last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
join_date: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
profile_colour: str | None = Field(default=None, max_length=7)
|
||||
profile_hue: int | None = Field(default=None)
|
||||
|
||||
# 社交媒体和个人资料字段
|
||||
avatar_url: str | None = Field(default=None, max_length=500)
|
||||
cover_url: str | None = Field(default=None, max_length=500)
|
||||
discord: str | None = Field(default=None, max_length=100)
|
||||
twitter: str | None = Field(default=None, max_length=100)
|
||||
website: str | None = Field(default=None, max_length=500)
|
||||
title: str | None = Field(default=None, max_length=100)
|
||||
title_url: str | None = Field(default=None, max_length=500)
|
||||
interests: str | None = Field(default=None, sa_column=Column(Text))
|
||||
location: str | None = Field(default=None, max_length=100)
|
||||
|
||||
occupation: str | None = Field(default=None) # 职业字段,默认为 None
|
||||
|
||||
# 游戏相关字段
|
||||
playmode: str = Field(default="osu", max_length=10)
|
||||
support_level: int = Field(default=0)
|
||||
max_blocks: int = Field(default=100)
|
||||
max_friends: int = Field(default=500)
|
||||
post_count: int = Field(default=0)
|
||||
|
||||
# 页面内容
|
||||
page_html: str | None = Field(default=None, sa_column=Column(Text))
|
||||
page_raw: str | None = Field(default=None, sa_column=Column(Text))
|
||||
|
||||
profile_order: str = Field(
|
||||
default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu"
|
||||
)
|
||||
|
||||
# 关联关系
|
||||
user: "User" = Relationship(back_populates="lazer_profile")
|
||||
|
||||
|
||||
class LazerUserProfileSections(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
section_name: str = Field(sa_column=Column(VARCHAR(50)))
|
||||
display_order: int | None = Field(default=None)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="lazer_profile_sections")
|
||||
|
||||
|
||||
class LazerUserCountry(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
code: str = Field(max_length=2)
|
||||
name: str = Field(max_length=100)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
|
||||
class LazerUserKudosu(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
available: int = Field(default=0)
|
||||
total: int = Field(default=0)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
|
||||
class LazerUserCounts(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 统计计数字段
|
||||
beatmap_playcounts_count: int = Field(default=0)
|
||||
comments_count: int = Field(default=0)
|
||||
favourite_beatmapset_count: int = Field(default=0)
|
||||
follower_count: int = Field(default=0)
|
||||
graveyard_beatmapset_count: int = Field(default=0)
|
||||
guest_beatmapset_count: int = Field(default=0)
|
||||
loved_beatmapset_count: int = Field(default=0)
|
||||
mapping_follower_count: int = Field(default=0)
|
||||
nominated_beatmapset_count: int = Field(default=0)
|
||||
pending_beatmapset_count: int = Field(default=0)
|
||||
ranked_beatmapset_count: int = Field(default=0)
|
||||
ranked_and_approved_beatmapset_count: int = Field(default=0)
|
||||
unranked_beatmapset_count: int = Field(default=0)
|
||||
scores_best_count: int = Field(default=0)
|
||||
scores_first_count: int = Field(default=0)
|
||||
scores_pinned_count: int = Field(default=0)
|
||||
scores_recent_count: int = Field(default=0)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
# 关联关系
|
||||
user: "User" = Relationship(back_populates="lazer_counts")
|
||||
|
||||
|
||||
class LazerUserStatistics(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
mode: str = Field(default="osu", max_length=10, primary_key=True)
|
||||
|
||||
# 基本命中统计
|
||||
count_100: int = Field(default=0)
|
||||
count_300: int = Field(default=0)
|
||||
count_50: int = Field(default=0)
|
||||
count_miss: int = Field(default=0)
|
||||
|
||||
# 等级信息
|
||||
level_current: int = Field(default=1)
|
||||
level_progress: int = Field(default=0)
|
||||
|
||||
# 排名信息
|
||||
global_rank: int | None = Field(default=None)
|
||||
global_rank_exp: int | None = Field(default=None)
|
||||
country_rank: int | None = Field(default=None)
|
||||
|
||||
# PP 和分数
|
||||
pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
|
||||
pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
|
||||
ranked_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2)))
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
total_hits: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
maximum_combo: int = Field(default=0)
|
||||
|
||||
# 游戏统计
|
||||
play_count: int = Field(default=0)
|
||||
play_time: int = Field(default=0) # 秒
|
||||
replays_watched_by_others: int = Field(default=0)
|
||||
is_ranked: bool = Field(default=False)
|
||||
|
||||
# 成绩等级计数
|
||||
grade_ss: int = Field(default=0)
|
||||
grade_ssh: int = Field(default=0)
|
||||
grade_s: int = Field(default=0)
|
||||
grade_sh: int = Field(default=0)
|
||||
grade_a: int = Field(default=0)
|
||||
|
||||
# 最高排名记录
|
||||
rank_highest: int | None = Field(default=None)
|
||||
rank_highest_updated_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
# 关联关系
|
||||
user: "User" = Relationship(back_populates="lazer_statistics")
|
||||
|
||||
|
||||
class LazerUserBanners(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
tournament_id: int
|
||||
image_url: str = Field(sa_column=Column(VARCHAR(500)))
|
||||
is_active: bool | None = Field(default=None)
|
||||
|
||||
# 修正user关系的back_populates值
|
||||
user: "User" = Relationship(back_populates="active_banners")
|
||||
|
||||
|
||||
class LazerUserAchievement(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
achievement_id: int
|
||||
achieved_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="lazer_achievements")
|
||||
|
||||
|
||||
class LazerUserBadge(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
badge_id: int
|
||||
awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
description: str | None = Field(default=None, sa_column=Column(Text))
|
||||
image_url: str | None = Field(default=None, max_length=500)
|
||||
url: str | None = Field(default=None, max_length=500)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="lazer_badges")
|
||||
|
||||
|
||||
class LazerUserMonthlyPlaycounts(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
start_date: datetime = Field(sa_column=Column(Date))
|
||||
play_count: int = Field(default=0)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="lazer_monthly_playcounts")
|
||||
|
||||
|
||||
class LazerUserPreviousUsername(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
username: str = Field(max_length=32)
|
||||
changed_at: datetime = Field(sa_column=Column(DateTime))
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="lazer_previous_usernames")
|
||||
|
||||
|
||||
class LazerUserReplaysWatched(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
start_date: datetime = Field(sa_column=Column(Date))
|
||||
count: int = Field(default=0)
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="lazer_replays_watched")
|
||||
|
||||
|
||||
# 类型转换用的 UserAchievement(不是 SQLAlchemy 模型)
|
||||
@dataclass
|
||||
class UserAchievement:
|
||||
achieved_at: datetime
|
||||
achievement_id: int
|
||||
|
||||
|
||||
class DailyChallengeStats(SQLModel, table=True):
|
||||
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
|
||||
)
|
||||
|
||||
daily_streak_best: int = Field(default=0)
|
||||
daily_streak_current: int = Field(default=0)
|
||||
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
last_weekly_streak: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime)
|
||||
)
|
||||
playcount: int = Field(default=0)
|
||||
top_10p_placements: int = Field(default=0)
|
||||
top_50p_placements: int = Field(default=0)
|
||||
weekly_streak_best: int = Field(default=0)
|
||||
weekly_streak_current: int = Field(default=0)
|
||||
|
||||
user: "User" = Relationship(back_populates="daily_challenge_stats")
|
||||
|
||||
|
||||
class RankHistory(SQLModel, table=True):
|
||||
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
mode: str = Field(max_length=10)
|
||||
rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
|
||||
date_recorded: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="rank_history")
|
||||
|
||||
|
||||
class UserAvatar(SQLModel, table=True):
|
||||
__tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
filename: str = Field(max_length=255)
|
||||
original_filename: str = Field(max_length=255)
|
||||
file_size: int
|
||||
mime_type: str = Field(max_length=100)
|
||||
is_active: bool = Field(default=True)
|
||||
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
r2_original_url: str | None = Field(default=None, max_length=500)
|
||||
r2_game_url: str | None = Field(default=None, max_length=500)
|
||||
|
||||
user: "User" = Relationship(back_populates="avatar")
|
||||
45
app/database/user_account_history.py
Normal file
45
app/database/user_account_history.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, SQLModel
|
||||
|
||||
|
||||
class UserAccountHistoryType(str, Enum):
|
||||
NOTE = "note"
|
||||
RESTRICTION = "restriction"
|
||||
SLIENCE = "silence"
|
||||
TOURNAMENT_BAN = "tournament_ban"
|
||||
|
||||
|
||||
class UserAccountHistoryBase(SQLModel, UTCBaseModel):
|
||||
description: str | None = None
|
||||
length: int
|
||||
permanent: bool = False
|
||||
timestamp: datetime = Field(default=datetime.now(UTC))
|
||||
type: UserAccountHistoryType
|
||||
|
||||
|
||||
class UserAccountHistory(UserAccountHistoryBase, table=True):
|
||||
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
sa_column=Column(
|
||||
Integer,
|
||||
autoincrement=True,
|
||||
index=True,
|
||||
primary_key=True,
|
||||
)
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
|
||||
|
||||
class UserAccountHistoryResp(UserAccountHistoryBase):
|
||||
id: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_model: UserAccountHistory) -> "UserAccountHistoryResp":
|
||||
return cls.model_validate(db_model)
|
||||
@@ -5,15 +5,11 @@ import json
|
||||
from app.config import settings
|
||||
|
||||
from pydantic import BaseModel
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
redis = None
|
||||
|
||||
|
||||
def json_serializer(value):
|
||||
if isinstance(value, BaseModel | SQLModel):
|
||||
@@ -25,10 +21,7 @@ def json_serializer(value):
|
||||
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
|
||||
|
||||
# Redis 连接
|
||||
if redis:
|
||||
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
else:
|
||||
redis_client = None
|
||||
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.log import logger
|
||||
fetcher: Fetcher | None = None
|
||||
|
||||
|
||||
def get_fetcher() -> Fetcher:
|
||||
async def get_fetcher() -> Fetcher:
|
||||
global fetcher
|
||||
if fetcher is None:
|
||||
fetcher = Fetcher(
|
||||
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
|
||||
settings.FETCHER_CALLBACK_URL,
|
||||
)
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
|
||||
if access_token:
|
||||
fetcher.access_token = str(access_token)
|
||||
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
|
||||
if refresh_token:
|
||||
fetcher.refresh_token = str(refresh_token)
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
logger.opt(colors=True).info(
|
||||
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
|
||||
)
|
||||
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
|
||||
if access_token:
|
||||
fetcher.access_token = str(access_token)
|
||||
refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
|
||||
if refresh_token:
|
||||
fetcher.refresh_token = str(refresh_token)
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
logger.opt(colors=True).info(
|
||||
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
|
||||
)
|
||||
return fetcher
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database import User
|
||||
|
||||
from .database import get_db
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
@@ -17,7 +16,7 @@ security = HTTPBearer()
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> DBUser:
|
||||
) -> User:
|
||||
"""获取当前认证用户"""
|
||||
token = credentials.credentials
|
||||
|
||||
@@ -27,13 +26,9 @@ async def get_current_user(
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
|
||||
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
user = (
|
||||
await db.exec(
|
||||
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
|
||||
)
|
||||
).first()
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
|
||||
@@ -59,16 +59,15 @@ class BaseFetcher:
|
||||
self.refresh_token = token_data.get("refresh_token", "")
|
||||
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
redis.set(
|
||||
f"fetcher:access_token:{self.client_id}",
|
||||
self.access_token,
|
||||
ex=token_data["expires_in"],
|
||||
)
|
||||
redis.set(
|
||||
f"fetcher:refresh_token:{self.client_id}",
|
||||
self.refresh_token,
|
||||
)
|
||||
await redis.set(
|
||||
f"fetcher:access_token:{self.client_id}",
|
||||
self.access_token,
|
||||
ex=token_data["expires_in"],
|
||||
)
|
||||
await redis.set(
|
||||
f"fetcher:refresh_token:{self.client_id}",
|
||||
self.refresh_token,
|
||||
)
|
||||
|
||||
async def refresh_access_token(self) -> None:
|
||||
async with AsyncClient() as client:
|
||||
@@ -87,13 +86,12 @@ class BaseFetcher:
|
||||
self.refresh_token = token_data.get("refresh_token", "")
|
||||
self.token_expiry = int(time.time()) + token_data["expires_in"]
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
redis.set(
|
||||
f"fetcher:access_token:{self.client_id}",
|
||||
self.access_token,
|
||||
ex=token_data["expires_in"],
|
||||
)
|
||||
redis.set(
|
||||
f"fetcher:refresh_token:{self.client_id}",
|
||||
self.refresh_token,
|
||||
)
|
||||
await redis.set(
|
||||
f"fetcher:access_token:{self.client_id}",
|
||||
self.access_token,
|
||||
ex=token_data["expires_in"],
|
||||
)
|
||||
await redis.set(
|
||||
f"fetcher:refresh_token:{self.client_id}",
|
||||
self.refresh_token,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
from loguru import logger
|
||||
import redis
|
||||
import redis.asyncio as redis
|
||||
|
||||
|
||||
class OsuDotDirectFetcher(BaseFetcher):
|
||||
@@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher):
|
||||
async def get_or_fetch_beatmap_raw(
|
||||
self, redis: redis.Redis, beatmap_id: int
|
||||
) -> str:
|
||||
if redis.exists(f"beatmap:{beatmap_id}:raw"):
|
||||
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
|
||||
if await redis.exists(f"beatmap:{beatmap_id}:raw"):
|
||||
return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
|
||||
raw = await self.get_beatmap_raw(beatmap_id)
|
||||
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
|
||||
await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
|
||||
return raw
|
||||
|
||||
@@ -42,11 +42,12 @@ class Language(IntEnum):
|
||||
KOREAN = 6
|
||||
FRENCH = 7
|
||||
GERMAN = 8
|
||||
ITALIAN = 9
|
||||
SPANISH = 10
|
||||
RUSSIAN = 11
|
||||
POLISH = 12
|
||||
OTHER = 13
|
||||
SWEDISH = 9
|
||||
ITALIAN = 10
|
||||
SPANISH = 11
|
||||
RUSSIAN = 12
|
||||
POLISH = 13
|
||||
OTHER = 14
|
||||
|
||||
|
||||
class BeatmapAttributes(BaseModel):
|
||||
|
||||
@@ -1,114 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any, Literal
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from app.models.signalr import UserState
|
||||
from app.models.signalr import SignalRUnionMessage, UserState
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class _UserActivity(BaseModel):
|
||||
model_config = ConfigDict(serialize_by_alias=True)
|
||||
type: Literal[
|
||||
"ChoosingBeatmap",
|
||||
"InSoloGame",
|
||||
"WatchingReplay",
|
||||
"SpectatingUser",
|
||||
"SearchingForLobby",
|
||||
"InLobby",
|
||||
"InMultiplayerGame",
|
||||
"SpectatingMultiplayerGame",
|
||||
"InPlaylistGame",
|
||||
"EditingBeatmap",
|
||||
"ModdingBeatmap",
|
||||
"TestingBeatmap",
|
||||
"InDailyChallengeLobby",
|
||||
"PlayingDailyChallenge",
|
||||
] = Field(alias="$dtype")
|
||||
value: Any | None = Field(alias="$value")
|
||||
class _UserActivity(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChoosingBeatmap(_UserActivity):
|
||||
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InGameValue(BaseModel):
|
||||
beatmap_id: int = Field(alias="BeatmapID")
|
||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
||||
ruleset_id: int = Field(alias="RulesetID")
|
||||
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
|
||||
union_type: ClassVar[Literal[11]] = 11
|
||||
|
||||
|
||||
class _InGame(_UserActivity):
|
||||
value: InGameValue = Field(alias="$value")
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
ruleset_id: int
|
||||
ruleset_playing_verb: str
|
||||
|
||||
|
||||
class InSoloGame(_InGame):
|
||||
type: Literal["InSoloGame"] = Field(alias="$dtype")
|
||||
union_type: ClassVar[Literal[12]] = 12
|
||||
|
||||
|
||||
class InMultiplayerGame(_InGame):
|
||||
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
|
||||
union_type: ClassVar[Literal[23]] = 23
|
||||
|
||||
|
||||
class SpectatingMultiplayerGame(_InGame):
|
||||
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
|
||||
union_type: ClassVar[Literal[24]] = 24
|
||||
|
||||
|
||||
class InPlaylistGame(_InGame):
|
||||
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
|
||||
union_type: ClassVar[Literal[31]] = 31
|
||||
|
||||
|
||||
class EditingBeatmapValue(BaseModel):
|
||||
beatmap_id: int = Field(alias="BeatmapID")
|
||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
||||
class PlayingDailyChallenge(_InGame):
|
||||
union_type: ClassVar[Literal[52]] = 52
|
||||
|
||||
|
||||
class EditingBeatmap(_UserActivity):
|
||||
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
|
||||
value: EditingBeatmapValue = Field(alias="$value")
|
||||
union_type: ClassVar[Literal[41]] = 41
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
|
||||
|
||||
class TestingBeatmap(_UserActivity):
|
||||
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
|
||||
class TestingBeatmap(EditingBeatmap):
|
||||
union_type: ClassVar[Literal[43]] = 43
|
||||
|
||||
|
||||
class ModdingBeatmap(_UserActivity):
|
||||
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class WatchingReplayValue(BaseModel):
|
||||
score_id: int = Field(alias="ScoreID")
|
||||
player_name: str = Field(alias="PlayerName")
|
||||
beatmap_id: int = Field(alias="BeatmapID")
|
||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
||||
class ModdingBeatmap(EditingBeatmap):
|
||||
union_type: ClassVar[Literal[42]] = 42
|
||||
|
||||
|
||||
class WatchingReplay(_UserActivity):
|
||||
type: Literal["WatchingReplay"] = Field(alias="$dtype")
|
||||
value: int | None = Field(alias="$value") # Replay ID
|
||||
union_type: ClassVar[Literal[13]] = 13
|
||||
score_id: int
|
||||
player_name: str
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
|
||||
|
||||
class SpectatingUser(WatchingReplay):
|
||||
type: Literal["SpectatingUser"] = Field(alias="$dtype")
|
||||
union_type: ClassVar[Literal[14]] = 14
|
||||
|
||||
|
||||
class SearchingForLobby(_UserActivity):
|
||||
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InLobbyValue(BaseModel):
|
||||
room_id: int = Field(alias="RoomID")
|
||||
room_name: str = Field(alias="RoomName")
|
||||
union_type: ClassVar[Literal[21]] = 21
|
||||
|
||||
|
||||
class InLobby(_UserActivity):
|
||||
type: Literal["InLobby"] = "InLobby"
|
||||
union_type: ClassVar[Literal[22]] = 22
|
||||
room_id: int
|
||||
room_name: str
|
||||
|
||||
|
||||
class InDailyChallengeLobby(_UserActivity):
|
||||
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
|
||||
union_type: ClassVar[Literal[51]] = 51
|
||||
|
||||
|
||||
UserActivity = (
|
||||
@@ -128,23 +99,25 @@ UserActivity = (
|
||||
)
|
||||
|
||||
|
||||
class MetadataClientState(UserState):
|
||||
user_activity: UserActivity | None = None
|
||||
status: OnlineStatus | None = None
|
||||
class UserPresence(BaseModel):
|
||||
activity: UserActivity | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any] | None:
|
||||
if self.status is None or self.status == OnlineStatus.OFFLINE:
|
||||
return None
|
||||
dumped = self.model_dump(by_alias=True, exclude_none=True)
|
||||
return {
|
||||
"Activity": dumped.get("user_activity"),
|
||||
"Status": dumped.get("status"),
|
||||
}
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
@property
|
||||
def pushable(self) -> bool:
|
||||
return self.status is not None and self.status != OnlineStatus.OFFLINE
|
||||
|
||||
@property
|
||||
def for_push(self) -> "UserPresence | None":
|
||||
return UserPresence(
|
||||
activity=self.activity,
|
||||
status=self.status,
|
||||
)
|
||||
|
||||
|
||||
class MetadataClientState(UserPresence, UserState): ...
|
||||
|
||||
|
||||
class OnlineStatus(IntEnum):
|
||||
OFFLINE = 0 # 隐身
|
||||
|
||||
22
app/models/model.py
Normal file
22
app/models/model.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
|
||||
class UTCBaseModel(BaseModel):
|
||||
@field_serializer("*", when_used="json")
|
||||
def serialize_datetime(self, v, _info):
|
||||
if isinstance(v, datetime):
|
||||
if v.tzinfo is None:
|
||||
v = v.replace(tzinfo=UTC)
|
||||
return v.astimezone(UTC).isoformat()
|
||||
return v
|
||||
|
||||
|
||||
Cursor = dict[str, int]
|
||||
|
||||
|
||||
class RespWithCursor(BaseModel):
|
||||
cursor: Cursor | None = None
|
||||
@@ -8,7 +8,7 @@ from app.path import STATIC_DIR
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: NotRequired[dict[str, bool | float | str]]
|
||||
settings: NotRequired[dict[str, bool | float | str | int]]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-api/wiki#mods
|
||||
|
||||
925
app/models/multiplayer_hub.py
Normal file
925
app/models/multiplayer_hub.py
Normal file
@@ -0,0 +1,925 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import IntEnum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
TypedDict,
|
||||
cast,
|
||||
override,
|
||||
)
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.dependencies.database import engine
|
||||
from app.exception import InvokeException
|
||||
|
||||
from .mods import APIMod
|
||||
from .room import (
|
||||
DownloadState,
|
||||
MatchType,
|
||||
MultiplayerRoomState,
|
||||
MultiplayerUserState,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomStatus,
|
||||
)
|
||||
from .signalr import (
|
||||
SignalRMeta,
|
||||
SignalRUnionMessage,
|
||||
UserState,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
PER_USER_LIMIT = 3
|
||||
|
||||
|
||||
class MultiplayerClientState(UserState):
|
||||
room_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomSettings(BaseModel):
|
||||
name: str = "Unnamed Room"
|
||||
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
password: str = ""
|
||||
match_type: MatchType = MatchType.HEAD_TO_HEAD
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_start_duration: timedelta = timedelta(seconds=0)
|
||||
auto_skip: bool = False
|
||||
|
||||
@property
|
||||
def auto_start_enabled(self) -> bool:
|
||||
return self.auto_start_duration != timedelta(seconds=0)
|
||||
|
||||
|
||||
class BeatmapAvailability(BaseModel):
|
||||
state: DownloadState = DownloadState.UNKNOWN
|
||||
download_progress: float | None = None
|
||||
|
||||
|
||||
class _MatchUserState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class TeamVersusUserState(_MatchUserState):
|
||||
team_id: int
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchUserState = TeamVersusUserState
|
||||
|
||||
|
||||
class _MatchRoomState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class MultiplayerTeam(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class TeamVersusRoomState(_MatchRoomState):
|
||||
teams: list[MultiplayerTeam] = Field(
|
||||
default_factory=lambda: [
|
||||
MultiplayerTeam(id=0, name="Team Red"),
|
||||
MultiplayerTeam(id=1, name="Team Blue"),
|
||||
]
|
||||
)
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchRoomState = TeamVersusRoomState
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str
|
||||
ruleset_id: int
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool
|
||||
playlist_order: int
|
||||
played_at: datetime | None = None
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _get_api_mods(self):
|
||||
from app.models.mods import API_MODS, init_mods
|
||||
|
||||
if not API_MODS:
|
||||
init_mods()
|
||||
return API_MODS
|
||||
|
||||
def _validate_mod_for_ruleset(
|
||||
self, mod: APIMod, ruleset_key: int, context: str = "mod"
|
||||
) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if (
|
||||
typed_ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[typed_ruleset_key]
|
||||
):
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is invalid for this ruleset"
|
||||
)
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not playable by users"
|
||||
)
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not valid for multiplayer"
|
||||
)
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
|
||||
if mod1_settings:
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(
|
||||
f"Mods {mod1['acronym']} and "
|
||||
f"{mod2['acronym']} are incompatible"
|
||||
)
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
for req_mod in self.required_mods:
|
||||
req_acronym = req_mod["acronym"]
|
||||
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
|
||||
if req_settings:
|
||||
incompatible = set(req_settings.get("IncompatibleMods", []))
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(
|
||||
f"Required mod {req_acronym} conflicts with "
|
||||
f"allowed mods: {conflict_list}"
|
||||
)
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
|
||||
# Validate required mods
|
||||
for mod in self.required_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
|
||||
|
||||
# Validate allowed mods
|
||||
for mod in self.allowed_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
|
||||
|
||||
# Check internal compatibility of required mods
|
||||
self._check_mod_compatibility(self.required_mods, ruleset_key)
|
||||
|
||||
# Check compatibility between required and allowed mods
|
||||
self._check_required_allowed_compatibility(ruleset_key)
|
||||
|
||||
def validate_user_mods(
|
||||
self,
|
||||
user: "MultiplayerRoomUser",
|
||||
proposed_mods: list[APIMod],
|
||||
) -> tuple[bool, list[APIMod]]:
|
||||
"""
|
||||
Validates user mods against playlist item rules and returns valid mods.
|
||||
Returns (is_valid, valid_mods).
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
|
||||
valid_mods = []
|
||||
all_proposed_valid = True
|
||||
|
||||
# Check if mods are valid for the ruleset
|
||||
for mod in proposed_mods:
|
||||
if (
|
||||
ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[ruleset_key]
|
||||
):
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
valid_mods.append(mod)
|
||||
|
||||
# Check mod compatibility within user mods
|
||||
incompatible_mods = set()
|
||||
final_valid_mods = []
|
||||
for mod in valid_mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
|
||||
if setting_mods:
|
||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||
final_valid_mods.append(mod)
|
||||
|
||||
# If not freestyle, check against allowed mods
|
||||
if not self.freestyle:
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
if mod["acronym"] not in allowed_acronyms:
|
||||
all_proposed_valid = False
|
||||
else:
|
||||
filtered_valid_mods.append(mod)
|
||||
final_valid_mods = filtered_valid_mods
|
||||
|
||||
# Check compatibility with required mods
|
||||
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||
all_mod_acronyms = {
|
||||
mod["acronym"] for mod in final_valid_mods
|
||||
} | required_mod_acronyms
|
||||
|
||||
# Check for incompatibility between required and user mods
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
mod_acronym = mod["acronym"]
|
||||
is_compatible = True
|
||||
|
||||
for other_acronym in all_mod_acronyms:
|
||||
if other_acronym == mod_acronym:
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
|
||||
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
|
||||
is_compatible = False
|
||||
all_proposed_valid = False
|
||||
break
|
||||
|
||||
if is_compatible:
|
||||
filtered_valid_mods.append(mod)
|
||||
|
||||
return all_proposed_valid, filtered_valid_mods
|
||||
|
||||
def clone(self) -> "PlaylistItem":
|
||||
copy = self.model_copy()
|
||||
copy.required_mods = list(self.required_mods)
|
||||
copy.allowed_mods = list(self.allowed_mods)
|
||||
copy.expired = False
|
||||
copy.played_at = None
|
||||
return copy
|
||||
|
||||
|
||||
class _MultiplayerCountdown(SignalRUnionMessage):
|
||||
id: int = 0
|
||||
time_remaining: timedelta
|
||||
is_exclusive: Annotated[
|
||||
bool, Field(default=True), SignalRMeta(member_ignore=True)
|
||||
] = True
|
||||
|
||||
|
||||
class MatchStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
|
||||
|
||||
MultiplayerCountdown = (
|
||||
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
user_id: int
|
||||
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
availability: BeatmapAvailability = BeatmapAvailability(
|
||||
state=DownloadState.UNKNOWN, download_progress=None
|
||||
)
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
match_state: MatchUserState | None = None
|
||||
ruleset_id: int | None = None # freestyle
|
||||
beatmap_id: int | None = None # freestyle
|
||||
|
||||
|
||||
class MultiplayerRoom(BaseModel):
|
||||
room_id: int
|
||||
state: MultiplayerRoomState
|
||||
settings: MultiplayerRoomSettings
|
||||
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
||||
host: MultiplayerRoomUser | None = None
|
||||
match_state: MatchRoomState | None = None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room) -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
|
||||
# 用户列表
|
||||
users = [MultiplayerRoomUser(user_id=room.host_id)]
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if hasattr(room, "playlist"):
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
id=item.id,
|
||||
owner_id=item.owner_id,
|
||||
beatmap_id=item.beatmap_id,
|
||||
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
|
||||
ruleset_id=item.ruleset_id,
|
||||
required_mods=item.required_mods,
|
||||
allowed_mods=item.allowed_mods,
|
||||
expired=item.expired,
|
||||
playlist_order=item.playlist_order,
|
||||
played_at=item.played_at,
|
||||
star_rating=item.beatmap.difficulty_rating
|
||||
if item.beatmap is not None
|
||||
else 0.0,
|
||||
freestyle=item.freestyle,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
room_id=room.id,
|
||||
state=getattr(room, "state", MultiplayerRoomState.OPEN),
|
||||
settings=MultiplayerRoomSettings(
|
||||
name=room.name,
|
||||
playlist_item_id=playlist[0].id if playlist else 0,
|
||||
password=getattr(room, "password", ""),
|
||||
match_type=room.type,
|
||||
queue_mode=room.queue_mode,
|
||||
auto_start_duration=timedelta(seconds=room.auto_start_duration),
|
||||
auto_skip=room.auto_skip,
|
||||
),
|
||||
users=users,
|
||||
host=host_user,
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=getattr(room, "channel_id", 0),
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerQueue:
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.server_room = room
|
||||
self.current_index = 0
|
||||
|
||||
@property
|
||||
def hub(self) -> "MultiplayerHub":
|
||||
return self.server_room.hub
|
||||
|
||||
@property
|
||||
def upcoming_items(self):
|
||||
return sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda i: i.playlist_order,
|
||||
)
|
||||
|
||||
@property
|
||||
def room(self):
|
||||
return self.server_room.room
|
||||
|
||||
async def update_order(self):
|
||||
from app.database import Playlist
|
||||
|
||||
match self.room.settings.queue_mode:
|
||||
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
|
||||
ordered_active_items = []
|
||||
|
||||
is_first_set = True
|
||||
first_set_order_by_user_id = {}
|
||||
|
||||
active_items = [item for item in self.room.playlist if not item.expired]
|
||||
active_items.sort(key=lambda x: x.id)
|
||||
|
||||
user_item_groups = {}
|
||||
for item in active_items:
|
||||
if item.owner_id not in user_item_groups:
|
||||
user_item_groups[item.owner_id] = []
|
||||
user_item_groups[item.owner_id].append(item)
|
||||
|
||||
max_items = max(
|
||||
(len(items) for items in user_item_groups.values()), default=0
|
||||
)
|
||||
|
||||
for i in range(max_items):
|
||||
current_set = []
|
||||
for user_id, items in user_item_groups.items():
|
||||
if i < len(items):
|
||||
current_set.append(items[i])
|
||||
|
||||
if is_first_set:
|
||||
current_set.sort(
|
||||
key=lambda item: (item.playlist_order, item.id)
|
||||
)
|
||||
ordered_active_items.extend(current_set)
|
||||
first_set_order_by_user_id = {
|
||||
item.owner_id: idx
|
||||
for idx, item in enumerate(ordered_active_items)
|
||||
}
|
||||
else:
|
||||
current_set.sort(
|
||||
key=lambda item: first_set_order_by_user_id.get(
|
||||
item.owner_id, 0
|
||||
)
|
||||
)
|
||||
ordered_active_items.extend(current_set)
|
||||
|
||||
is_first_set = False
|
||||
|
||||
for idx, item in enumerate(ordered_active_items):
|
||||
item.playlist_order = idx
|
||||
case _:
|
||||
ordered_active_items = sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
async with AsyncSession(engine) as session:
|
||||
for idx, item in enumerate(ordered_active_items):
|
||||
if item.playlist_order == idx:
|
||||
continue
|
||||
item.playlist_order = idx
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room, item, beatmap_changed=False
|
||||
)
|
||||
|
||||
async def update_current_item(self):
|
||||
upcoming_items = self.upcoming_items
|
||||
next_item = (
|
||||
upcoming_items[0]
|
||||
if upcoming_items
|
||||
else max(
|
||||
self.room.playlist,
|
||||
key=lambda i: i.played_at or datetime.min,
|
||||
)
|
||||
)
|
||||
self.current_index = self.room.playlist.index(next_item)
|
||||
last_id = self.room.settings.playlist_item_id
|
||||
self.room.settings.playlist_item_id = next_item.id
|
||||
if last_id != next_item.id:
|
||||
await self.hub.setting_changed(self.server_room, True)
|
||||
|
||||
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
is_host = self.room.host and self.room.host.user_id == user.user_id
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
|
||||
raise InvokeException("You are not the host")
|
||||
|
||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||
if (
|
||||
len([True for u in self.room.playlist if u.owner_id == user.user_id])
|
||||
>= limit
|
||||
):
|
||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session:
|
||||
beatmap = await session.get(Beatmap, item.beatmap_id)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(
|
||||
beatmap.difficulty_rating
|
||||
) # FIXME: beatmap use decimal
|
||||
await Playlist.add_to_db(item, self.room.room_id, session)
|
||||
self.room.playlist.append(item)
|
||||
await self.hub.playlist_added(self.server_room, item)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session:
|
||||
beatmap = await session.get(Beatmap, item.beatmap_id)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
existing_item = next(
|
||||
(i for i in self.room.playlist if i.id == item.id), None
|
||||
)
|
||||
if existing_item is None:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item that doesn't exist"
|
||||
)
|
||||
|
||||
if existing_item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which is not owned by the user"
|
||||
)
|
||||
|
||||
if existing_item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which has already been played"
|
||||
)
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(beatmap.difficulty_rating)
|
||||
item.playlist_order = existing_item.playlist_order
|
||||
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
|
||||
# Update item in playlist
|
||||
for idx, playlist_item in enumerate(self.room.playlist):
|
||||
if playlist_item.id == item.id:
|
||||
self.room.playlist[idx] = item
|
||||
break
|
||||
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room,
|
||||
item,
|
||||
beatmap_changed=item.beatmap_checksum
|
||||
!= existing_item.beatmap_checksum,
|
||||
)
|
||||
|
||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
item = next(
|
||||
(i for i in self.room.playlist if i.id == playlist_item_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if item is None:
|
||||
raise InvokeException("Item does not exist in the room")
|
||||
|
||||
# Check if it's the only item and current item
|
||||
if item == self.current_item:
|
||||
upcoming_items = [i for i in self.room.playlist if not i.expired]
|
||||
if len(upcoming_items) == 1:
|
||||
raise InvokeException("The only item in the room cannot be removed")
|
||||
|
||||
if item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which is not owned by the user"
|
||||
)
|
||||
|
||||
if item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which has already been played"
|
||||
)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
|
||||
self.room.playlist.remove(item)
|
||||
self.current_index = self.room.playlist.index(self.upcoming_items[0])
|
||||
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
await self.hub.playlist_removed(self.server_room, item.id)
|
||||
|
||||
async def finish_current_item(self):
|
||||
from app.database import Playlist
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
played_at = datetime.now(UTC)
|
||||
await session.execute(
|
||||
update(Playlist)
|
||||
.where(
|
||||
col(Playlist.id) == self.current_item.id,
|
||||
col(Playlist.room_id) == self.room.room_id,
|
||||
)
|
||||
.values(expired=True, played_at=played_at)
|
||||
)
|
||||
self.room.playlist[self.current_index].expired = True
|
||||
self.room.playlist[self.current_index].played_at = played_at
|
||||
await self.hub.playlist_changed(self.server_room, self.current_item, True)
|
||||
await self.update_order()
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
|
||||
async def update_queue_mode(self):
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
@property
|
||||
def current_item(self):
|
||||
return self.room.playlist[self.current_index]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownInfo:
|
||||
countdown: MultiplayerCountdown
|
||||
duration: timedelta
|
||||
task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, countdown: MultiplayerCountdown):
|
||||
self.countdown = countdown
|
||||
self.duration = (
|
||||
countdown.time_remaining
|
||||
if countdown.time_remaining > timedelta(seconds=0)
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
class _MatchRequest(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChangeTeamRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
team_id: int
|
||||
|
||||
|
||||
class StartMatchCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
duration: timedelta
|
||||
|
||||
|
||||
class StopCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
id: int
|
||||
|
||||
|
||||
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
|
||||
|
||||
|
||||
class MatchTypeHandler(ABC):
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.room = room
|
||||
self.hub = room.hub
|
||||
|
||||
@abstractmethod
|
||||
async def handle_join(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_details(self) -> MatchStartedEventDetail: ...
|
||||
|
||||
|
||||
class HeadToHeadHandler(MatchTypeHandler):
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
if user.match_state is not None:
|
||||
user.match_state = None
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
|
||||
return detail
|
||||
|
||||
|
||||
class TeamVersusHandler(MatchTypeHandler):
|
||||
@override
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
super().__init__(room)
|
||||
self.state = TeamVersusRoomState()
|
||||
room.room.match_state = self.state
|
||||
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
|
||||
self.hub.tasks.add(task)
|
||||
task.add_done_callback(self.hub.tasks.discard)
|
||||
|
||||
def _get_best_available_team(self) -> int:
|
||||
for team in self.state.teams:
|
||||
if all(
|
||||
(
|
||||
user.match_state is None
|
||||
or not isinstance(user.match_state, TeamVersusUserState)
|
||||
or user.match_state.team_id != team.id
|
||||
)
|
||||
for user in self.room.room.users
|
||||
):
|
||||
return team.id
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
team_counts = defaultdict(int)
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
team_counts[user.match_state.team_id] += 1
|
||||
|
||||
if team_counts:
|
||||
min_count = min(team_counts.values())
|
||||
for team_id, count in team_counts.items():
|
||||
if count == min_count:
|
||||
return team_id
|
||||
return self.state.teams[0].id if self.state.teams else 0
|
||||
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
best_team_id = self._get_best_available_team()
|
||||
user.match_state = TeamVersusUserState(team_id=best_team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
|
||||
if not isinstance(request, ChangeTeamRequest):
|
||||
return
|
||||
|
||||
if request.team_id not in [team.id for team in self.state.teams]:
|
||||
raise InvokeException("Invalid team ID")
|
||||
|
||||
user.match_state = TeamVersusUserState(team_id=request.team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
teams: dict[int, Literal["blue", "red"]] = {}
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
|
||||
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
|
||||
return detail
|
||||
|
||||
|
||||
MATCH_TYPE_HANDLERS = {
|
||||
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
|
||||
MatchType.TEAM_VERSUS: TeamVersusHandler,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMultiplayerRoom:
|
||||
room: MultiplayerRoom
|
||||
category: RoomCategory
|
||||
status: RoomStatus
|
||||
start_at: datetime
|
||||
hub: "MultiplayerHub"
|
||||
match_type_handler: MatchTypeHandler
|
||||
queue: MultiplayerQueue
|
||||
_next_countdown_id: int
|
||||
_countdown_id_lock: asyncio.Lock
|
||||
_tracked_countdown: dict[int, CountdownInfo]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room: MultiplayerRoom,
|
||||
category: RoomCategory,
|
||||
start_at: datetime,
|
||||
hub: "MultiplayerHub",
|
||||
):
|
||||
self.room = room
|
||||
self.category = category
|
||||
self.status = RoomStatus.IDLE
|
||||
self.start_at = start_at
|
||||
self.hub = hub
|
||||
self.queue = MultiplayerQueue(self)
|
||||
self._next_countdown_id = 0
|
||||
self._countdown_id_lock = asyncio.Lock()
|
||||
self._tracked_countdown = {}
|
||||
|
||||
async def set_handler(self):
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
|
||||
self
|
||||
)
|
||||
for i in self.room.users:
|
||||
await self.match_type_handler.handle_join(i)
|
||||
|
||||
async def get_next_countdown_id(self) -> int:
|
||||
async with self._countdown_id_lock:
|
||||
self._next_countdown_id += 1
|
||||
return self._next_countdown_id
|
||||
|
||||
async def start_countdown(
|
||||
self,
|
||||
countdown: MultiplayerCountdown,
|
||||
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
|
||||
):
|
||||
async def _countdown_task(self: "ServerMultiplayerRoom"):
|
||||
await asyncio.sleep(info.duration.total_seconds())
|
||||
if on_complete is not None:
|
||||
await on_complete(self)
|
||||
await self.stop_countdown(countdown)
|
||||
|
||||
if countdown.is_exclusive:
|
||||
await self.stop_all_countdowns()
|
||||
countdown.id = await self.get_next_countdown_id()
|
||||
info = CountdownInfo(countdown)
|
||||
self.room.active_countdowns.append(info.countdown)
|
||||
self._tracked_countdown[countdown.id] = info
|
||||
await self.hub.send_match_event(
|
||||
self, CountdownStartedEvent(countdown=info.countdown)
|
||||
)
|
||||
info.task = asyncio.create_task(_countdown_task(self))
|
||||
|
||||
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||
info = self._tracked_countdown.get(countdown.id)
|
||||
if info is None:
|
||||
return
|
||||
del self._tracked_countdown[countdown.id]
|
||||
self.room.active_countdowns.remove(countdown)
|
||||
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
|
||||
if info.task is not None and not info.task.done():
|
||||
info.task.cancel()
|
||||
|
||||
async def stop_all_countdowns(self):
|
||||
for countdown in list(self._tracked_countdown.values()):
|
||||
await self.stop_countdown(countdown.countdown)
|
||||
|
||||
self._tracked_countdown.clear()
|
||||
self.room.active_countdowns.clear()
|
||||
|
||||
|
||||
class _MatchServerEvent(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class CountdownStartedEvent(_MatchServerEvent):
|
||||
countdown: MultiplayerCountdown
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class CountdownStoppedEvent(_MatchServerEvent):
|
||||
id: int
|
||||
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
|
||||
|
||||
|
||||
class GameplayAbortReason(IntEnum):
|
||||
LOAD_TOOK_TOO_LONG = 0
|
||||
HOST_ABORTED = 1
|
||||
|
||||
|
||||
class MatchStartedEventDetail(TypedDict):
|
||||
room_type: Literal["playlists", "head_to_head", "team_versus"]
|
||||
team: dict[int, Literal["blue", "red"]] | None
|
||||
@@ -1,7 +1,6 @@
|
||||
# OAuth 相关模型
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
"""注册错误响应模型"""
|
||||
|
||||
form_error: dict
|
||||
|
||||
|
||||
class UserRegistrationErrors(BaseModel):
|
||||
"""用户注册错误模型"""
|
||||
username: List[str] = []
|
||||
user_email: List[str] = []
|
||||
password: List[str] = []
|
||||
|
||||
username: list[str] = []
|
||||
user_email: list[str] = []
|
||||
password: list[str] = []
|
||||
|
||||
|
||||
class RegistrationRequestErrors(BaseModel):
|
||||
"""注册请求错误模型"""
|
||||
|
||||
message: str | None = None
|
||||
redirect: str | None = None
|
||||
user: UserRegistrationErrors | None = None
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.user import User as DBUser
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.mods import APIMod
|
||||
from app.models.user import User
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RoomCategory(str, Enum):
|
||||
@@ -62,53 +53,24 @@ class MultiplayerUserState(str, Enum):
|
||||
RESULTS = "results"
|
||||
SPECTATING = "spectating"
|
||||
|
||||
@property
|
||||
def is_playing(self) -> bool:
|
||||
return self in {
|
||||
self.WAITING_FOR_LOAD,
|
||||
self.PLAYING,
|
||||
self.READY_FOR_GAMEPLAY,
|
||||
self.LOADED,
|
||||
}
|
||||
|
||||
|
||||
class DownloadState(str, Enum):
|
||||
UNKOWN = "unkown"
|
||||
UNKNOWN = "unknown"
|
||||
NOT_DOWNLOADED = "not_downloaded"
|
||||
DOWNLOADING = "downloading"
|
||||
IMPORTING = "importing"
|
||||
LOCALLY_AVAILABLE = "locally_available"
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: int
|
||||
owner_id: int
|
||||
ruleset_id: int
|
||||
expired: bool
|
||||
playlist_order: int | None
|
||||
played_at: datetime | None
|
||||
allowed_mods: list[APIMod] = []
|
||||
required_mods: list[APIMod] = []
|
||||
beatmap_id: int
|
||||
beatmap: BeatmapResp | None
|
||||
freestyle: bool
|
||||
|
||||
class Config:
|
||||
exclude_none = True
|
||||
|
||||
@classmethod
|
||||
async def from_mpListItem(
|
||||
cls, item: MultiPlayerListItem, db: AsyncSession, fetcher: Fetcher
|
||||
):
|
||||
s = cls.model_validate(item.model_dump())
|
||||
s.id = item.id
|
||||
s.owner_id = item.OwnerID
|
||||
s.ruleset_id = item.RulesetID
|
||||
s.expired = item.Expired
|
||||
s.playlist_order = item.PlaylistOrder
|
||||
s.played_at = item.PlayedAt
|
||||
s.required_mods = item.RequierdMods
|
||||
s.allowed_mods = item.AllowedMods
|
||||
s.freestyle = item.Freestyle
|
||||
cur_beatmap = await Beatmap.get_or_fetch(
|
||||
db, fetcher=fetcher, bid=item.BeatmapID
|
||||
)
|
||||
s.beatmap = BeatmapResp.from_db(cur_beatmap)
|
||||
s.beatmap_id = item.BeatmapID
|
||||
return s
|
||||
|
||||
|
||||
class RoomPlaylistItemStats(BaseModel):
|
||||
count_active: int
|
||||
count_total: int
|
||||
@@ -120,269 +82,7 @@ class RoomDifficultyRange(BaseModel):
|
||||
max: float
|
||||
|
||||
|
||||
class ItemAttemptsCount(BaseModel):
|
||||
id: int
|
||||
attempts: int
|
||||
passed: bool
|
||||
|
||||
|
||||
class PlaylistAggregateScore(BaseModel):
|
||||
playlist_item_attempts: list[ItemAttemptsCount]
|
||||
|
||||
|
||||
class MultiplayerRoomSettings(BaseModel):
|
||||
Name: str = "Unnamed Room"
|
||||
PlaylistItemId: int
|
||||
Password: str = ""
|
||||
MatchType: MatchType
|
||||
QueueMode: QueueMode
|
||||
AutoStartDuration: timedelta
|
||||
AutoSkip: bool
|
||||
|
||||
@classmethod
|
||||
def from_apiRoom(cls, room: Room):
|
||||
s = cls.model_validate(room.model_dump())
|
||||
s.Name = room.name
|
||||
s.Password = room.password if room.password is not None else ""
|
||||
s.MatchType = room.type
|
||||
s.QueueMode = room.queue_mode
|
||||
s.AutoStartDuration = timedelta(seconds=room.auto_start_duration)
|
||||
s.AutoSkip = room.auto_skip
|
||||
return s
|
||||
|
||||
|
||||
class BeatmapAvailability(BaseModel):
|
||||
State: DownloadState
|
||||
DownloadProgress: float | None
|
||||
|
||||
|
||||
class MatchUserState(BaseModel):
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class TeamVersusState(MatchUserState):
|
||||
TeamId: int
|
||||
|
||||
|
||||
MatchUserStateType = TeamVersusState | MatchUserState
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
UserID: int
|
||||
State: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
BeatmapAvailability: BeatmapAvailability
|
||||
Mods: list[APIMod] = []
|
||||
MatchUserState: MatchUserStateType | None
|
||||
RulesetId: int | None
|
||||
BeatmapId: int | None
|
||||
User: User | None
|
||||
|
||||
@classmethod
|
||||
async def from_id(cls, id: int, db: AsyncSession):
|
||||
actualUser = (
|
||||
await db.exec(
|
||||
DBUser.all_select_clause().where(
|
||||
DBUser.id == id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
user = (
|
||||
await convert_db_user_to_api_user(actualUser)
|
||||
if actualUser is not None
|
||||
else None
|
||||
)
|
||||
return MultiplayerRoomUser(
|
||||
UserID=id,
|
||||
MatchUserState=None,
|
||||
BeatmapAvailability=BeatmapAvailability(
|
||||
State=DownloadState.UNKOWN, DownloadProgress=None
|
||||
),
|
||||
RulesetId=None,
|
||||
BeatmapId=None,
|
||||
User=user,
|
||||
)
|
||||
|
||||
|
||||
class MatchRoomState(BaseModel):
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class MultiPlayerTeam(BaseModel):
|
||||
id: int = 0
|
||||
name: str = ""
|
||||
|
||||
|
||||
class TeamVersusRoomState(BaseModel):
|
||||
teams: list[MultiPlayerTeam] = []
|
||||
|
||||
class Config:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def create_default(cls):
|
||||
return cls(
|
||||
teams=[
|
||||
MultiPlayerTeam(id=0, name="Team Red"),
|
||||
MultiPlayerTeam(id=1, name="Team Blue"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MatchRoomStateType = TeamVersusRoomState | MatchRoomState
|
||||
|
||||
|
||||
class MultiPlayerListItem(BaseModel):
|
||||
id: int
|
||||
OwnerID: int
|
||||
BeatmapID: int
|
||||
BeatmapChecksum: str = ""
|
||||
RulesetID: int
|
||||
RequierdMods: list[APIMod]
|
||||
AllowedMods: list[APIMod]
|
||||
Expired: bool
|
||||
PlaylistOrder: int | None
|
||||
PlayedAt: datetime | None
|
||||
StarRating: float
|
||||
Freestyle: bool
|
||||
|
||||
@classmethod
|
||||
async def from_apiItem(cls, item: PlaylistItem, db: AsyncSession, fetcher: Fetcher):
|
||||
s = cls.model_validate(item.model_dump())
|
||||
s.id = item.id
|
||||
s.OwnerID = item.owner_id
|
||||
if item.beatmap is None: # 从客户端接受的一定没有这字段
|
||||
cur_beatmap = await Beatmap.get_or_fetch(
|
||||
db, fetcher=fetcher, bid=item.beatmap_id
|
||||
)
|
||||
s.BeatmapID = cur_beatmap.id if cur_beatmap.id is not None else 0
|
||||
s.BeatmapChecksum = cur_beatmap.checksum
|
||||
s.StarRating = cur_beatmap.difficulty_rating
|
||||
s.RulesetID = item.ruleset_id
|
||||
s.RequierdMods = item.required_mods
|
||||
s.AllowedMods = item.allowed_mods
|
||||
s.Expired = item.expired
|
||||
s.PlaylistOrder = item.playlist_order if item.playlist_order is not None else 0
|
||||
s.PlayedAt = item.played_at
|
||||
s.Freestyle = item.freestyle
|
||||
return s
|
||||
|
||||
|
||||
class MultiplayerCountdown(BaseModel):
|
||||
id: int = 0
|
||||
time_remaining: timedelta = timedelta(seconds=0)
|
||||
is_exclusive: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class MatchStartCountdown(MultiplayerCountdown):
|
||||
pass
|
||||
|
||||
|
||||
class ForceGameplayStartCountdown(MultiplayerCountdown):
|
||||
pass
|
||||
|
||||
|
||||
class ServerShuttingCountdown(MultiplayerCountdown):
|
||||
pass
|
||||
|
||||
|
||||
MultiplayerCountdownType = (
|
||||
MatchStartCountdown
|
||||
| ForceGameplayStartCountdown
|
||||
| ServerShuttingCountdown
|
||||
| MultiplayerCountdown
|
||||
)
|
||||
|
||||
|
||||
class PlaylistStatus(BaseModel):
|
||||
count_active: int
|
||||
count_total: int
|
||||
ruleset_ids: list[int]
|
||||
|
||||
|
||||
class MultiplayerRoom(BaseModel):
|
||||
RoomId: int
|
||||
State: MultiplayerRoomState
|
||||
Settings: MultiplayerRoomSettings = MultiplayerRoomSettings(
|
||||
PlaylistItemId=0,
|
||||
MatchType=MatchType.HEAD_TO_HEAD,
|
||||
QueueMode=QueueMode.HOST_ONLY,
|
||||
AutoStartDuration=timedelta(0),
|
||||
AutoSkip=False,
|
||||
)
|
||||
Users: list[MultiplayerRoomUser]
|
||||
Host: MultiplayerRoomUser
|
||||
MatchState: MatchRoomState | None
|
||||
Playlist: list[MultiPlayerListItem]
|
||||
ActivecCountDowns: list[MultiplayerCountdownType]
|
||||
ChannelID: int
|
||||
|
||||
@classmethod
|
||||
def CanAddPlayistItem(cls, user: MultiplayerRoomUser) -> bool:
|
||||
return user == cls.Host or cls.Settings.QueueMode != QueueMode.HOST_ONLY
|
||||
|
||||
@classmethod
|
||||
async def from_apiRoom(cls, room: Room, db: AsyncSession, fetcher: Fetcher):
|
||||
s = cls.model_validate(room.model_dump())
|
||||
s.RoomId = room.room_id if room.room_id is not None else 0
|
||||
s.ChannelID = room.channel_id
|
||||
s.Settings = MultiplayerRoomSettings.from_apiRoom(room)
|
||||
s.Host = await MultiplayerRoomUser.from_id(room.host.id if room.host else 0, db)
|
||||
s.Playlist = [
|
||||
await MultiPlayerListItem.from_apiItem(item, db, fetcher)
|
||||
for item in room.playlist
|
||||
]
|
||||
return s
|
||||
|
||||
|
||||
class Room(BaseModel):
|
||||
room_id: int
|
||||
name: str
|
||||
password: str | None
|
||||
has_password: bool = Field(exclude=True)
|
||||
host: User | None
|
||||
category: RoomCategory
|
||||
duration: int | None
|
||||
starts_at: datetime | None
|
||||
ends_at: datetime | None
|
||||
max_particapants: int | None = Field(exclude=True)
|
||||
particapant_count: int
|
||||
recent_particapants: list[User]
|
||||
type: MatchType
|
||||
max_attempts: int | None
|
||||
playlist: list[PlaylistItem]
|
||||
playlist_item_status: list[RoomPlaylistItemStats]
|
||||
difficulity_range: RoomDifficultyRange
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
current_user_score: PlaylistAggregateScore | None
|
||||
current_playlist_item: PlaylistItem | None
|
||||
channel_id: int
|
||||
status: RoomStatus
|
||||
availability: RoomAvailability = Field(exclude=True)
|
||||
|
||||
class Config:
|
||||
exclude_none = True
|
||||
|
||||
@classmethod
|
||||
async def from_mpRoom(
|
||||
cls, room: MultiplayerRoom, db: AsyncSession, fetcher: Fetcher
|
||||
):
|
||||
s = cls.model_validate(room.model_dump())
|
||||
s.room_id = room.RoomId
|
||||
s.name = room.Settings.Name
|
||||
s.password = room.Settings.Password
|
||||
s.type = room.Settings.MatchType
|
||||
s.queue_mode = room.Settings.QueueMode
|
||||
s.auto_skip = room.Settings.AutoSkip
|
||||
s.host = room.Host.User
|
||||
s.playlist = [
|
||||
await PlaylistItem.from_mpListItem(item, db, fetcher)
|
||||
for item in room.Playlist
|
||||
]
|
||||
return s
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
from enum import Enum
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from .mods import API_MODS, APIMod, init_mods
|
||||
@@ -93,52 +93,14 @@ class HitResult(str, Enum):
|
||||
)
|
||||
|
||||
|
||||
class HitResultInt(IntEnum):
|
||||
PERFECT = 0
|
||||
GREAT = 1
|
||||
GOOD = 2
|
||||
OK = 3
|
||||
MEH = 4
|
||||
MISS = 5
|
||||
|
||||
LARGE_TICK_HIT = 6
|
||||
SMALL_TICK_HIT = 7
|
||||
SLIDER_TAIL_HIT = 8
|
||||
|
||||
LARGE_BONUS = 9
|
||||
SMALL_BONUS = 10
|
||||
|
||||
LARGE_TICK_MISS = 11
|
||||
SMALL_TICK_MISS = 12
|
||||
|
||||
IGNORE_HIT = 13
|
||||
IGNORE_MISS = 14
|
||||
|
||||
NONE = 15
|
||||
COMBO_BREAK = 16
|
||||
|
||||
LEGACY_COMBO_INCREASE = 99
|
||||
|
||||
def is_hit(self) -> bool:
|
||||
return self not in (
|
||||
HitResultInt.NONE,
|
||||
HitResultInt.IGNORE_MISS,
|
||||
HitResultInt.COMBO_BREAK,
|
||||
HitResultInt.LARGE_TICK_MISS,
|
||||
HitResultInt.SMALL_TICK_MISS,
|
||||
HitResultInt.MISS,
|
||||
)
|
||||
|
||||
|
||||
class LeaderboardType(Enum):
|
||||
GLOBAL = "global"
|
||||
FRIENDS = "friends"
|
||||
FRIENDS = "friend"
|
||||
COUNTRY = "country"
|
||||
TEAM = "team"
|
||||
|
||||
|
||||
ScoreStatistics = dict[HitResult, int]
|
||||
ScoreStatisticsInt = dict[HitResultInt, int]
|
||||
|
||||
|
||||
class SoloScoreSubmissionInfo(BaseModel):
|
||||
@@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel):
|
||||
class LegacyReplaySoloScoreInfo(TypedDict):
|
||||
online_id: int
|
||||
mods: list[APIMod]
|
||||
statistics: ScoreStatisticsInt
|
||||
maximum_statistics: ScoreStatisticsInt
|
||||
statistics: ScoreStatistics
|
||||
maximum_statistics: ScoreStatistics
|
||||
client_version: str
|
||||
rank: Rank
|
||||
user_id: int
|
||||
|
||||
@@ -1,54 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, get_origin
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
def serialize_to_list(value: BaseModel) -> list[Any]:
|
||||
data = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
v = getattr(value, field)
|
||||
anno = get_origin(info.annotation)
|
||||
if anno and issubclass(anno, BaseModel):
|
||||
data.append(serialize_to_list(v))
|
||||
elif anno and issubclass(anno, list):
|
||||
data.append(
|
||||
TypeAdapter(
|
||||
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
|
||||
).dump_python(v)
|
||||
)
|
||||
elif isinstance(v, datetime.datetime):
|
||||
data.append([v, 0])
|
||||
else:
|
||||
data.append(v)
|
||||
return data
|
||||
@dataclass
|
||||
class SignalRMeta:
|
||||
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||
use_abbr: bool = True
|
||||
|
||||
|
||||
class MessagePackArrayModel(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def unpack(cls, v: Any) -> Any:
|
||||
if isinstance(v, list):
|
||||
fields = list(cls.model_fields.keys())
|
||||
if len(v) != len(fields):
|
||||
raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
|
||||
return dict(zip(fields, v))
|
||||
return v
|
||||
|
||||
@model_serializer
|
||||
def serialize(self) -> list[Any]:
|
||||
return serialize_to_list(self)
|
||||
class SignalRUnionMessage(BaseModel):
|
||||
union_type: ClassVar[int]
|
||||
|
||||
|
||||
class Transport(BaseModel):
|
||||
|
||||
@@ -2,17 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from .score import (
|
||||
ScoreStatisticsInt,
|
||||
ScoreStatistics,
|
||||
)
|
||||
from .signalr import MessagePackArrayModel, UserState
|
||||
from .signalr import SignalRMeta, UserState
|
||||
|
||||
from msgpack_lazer_api import APIMod
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SpectatedUserState(IntEnum):
|
||||
@@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum):
|
||||
Quit = 5
|
||||
|
||||
|
||||
class SpectatorState(MessagePackArrayModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
class SpectatorState(BaseModel):
|
||||
beatmap_id: int | None = None
|
||||
ruleset_id: int | None = None # 0,1,2,3
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
state: SpectatedUserState
|
||||
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
||||
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SpectatorState):
|
||||
@@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel):
|
||||
)
|
||||
|
||||
|
||||
class ScoreProcessorStatistics(MessagePackArrayModel):
|
||||
base_score: int
|
||||
maximum_base_score: int
|
||||
class ScoreProcessorStatistics(BaseModel):
|
||||
base_score: float
|
||||
maximum_base_score: float
|
||||
accuracy_judgement_count: int
|
||||
combo_portion: float
|
||||
bouns_portion: float
|
||||
bonus_portion: float
|
||||
|
||||
|
||||
class FrameHeader(MessagePackArrayModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
class FrameHeader(BaseModel):
|
||||
total_score: int
|
||||
acc: float
|
||||
accuracy: float
|
||||
combo: int
|
||||
max_combo: int
|
||||
statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
score_processor_statistics: ScoreProcessorStatistics
|
||||
received_time: datetime.datetime
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
@@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel):
|
||||
# SMOKE = 16
|
||||
|
||||
|
||||
class LegacyReplayFrame(MessagePackArrayModel):
|
||||
class LegacyReplayFrame(BaseModel):
|
||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||
x: float | None = None
|
||||
y: float | None = None
|
||||
mouse_x: float | None = None
|
||||
mouse_y: float | None = None
|
||||
button_state: int
|
||||
|
||||
header: Annotated[
|
||||
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
|
||||
]
|
||||
|
||||
class FrameDataBundle(MessagePackArrayModel):
|
||||
|
||||
class FrameDataBundle(BaseModel):
|
||||
header: FrameHeader
|
||||
frames: list[LegacyReplayFrame]
|
||||
|
||||
@@ -106,18 +106,16 @@ class APIUser(BaseModel):
|
||||
|
||||
|
||||
class ScoreInfo(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
mods: list[APIMod]
|
||||
user: APIUser
|
||||
ruleset: int
|
||||
maximum_statistics: ScoreStatisticsInt
|
||||
maximum_statistics: ScoreStatistics
|
||||
id: int | None = None
|
||||
total_score: int | None = None
|
||||
acc: float | None = None
|
||||
accuracy: float | None = None
|
||||
max_combo: int | None = None
|
||||
combo: int | None = None
|
||||
statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StoreScore(BaseModel):
|
||||
|
||||
@@ -2,15 +2,11 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .score import GameMode
|
||||
from .model import UTCBaseModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database import LazerUserAchievement, Team
|
||||
|
||||
|
||||
class PlayStyle(str, Enum):
|
||||
MOUSE = "mouse"
|
||||
@@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel):
|
||||
count: int
|
||||
|
||||
|
||||
class UserAchievement(BaseModel):
|
||||
achieved_at: datetime
|
||||
achievement_id: int
|
||||
|
||||
# 添加数据库模型转换方法
|
||||
def to_db_model(self, user_id: int) -> "LazerUserAchievement":
|
||||
from app.database import (
|
||||
LazerUserAchievement,
|
||||
)
|
||||
|
||||
return LazerUserAchievement(
|
||||
user_id=user_id,
|
||||
achievement_id=self.achievement_id,
|
||||
achieved_at=self.achieved_at,
|
||||
)
|
||||
|
||||
|
||||
class RankHighest(BaseModel):
|
||||
class RankHighest(UTCBaseModel):
|
||||
rank: int
|
||||
updated_at: datetime
|
||||
|
||||
@@ -104,115 +83,6 @@ class RankHistory(BaseModel):
|
||||
data: list[int]
|
||||
|
||||
|
||||
class DailyChallengeStats(BaseModel):
|
||||
daily_streak_best: int = 0
|
||||
daily_streak_current: int = 0
|
||||
last_update: datetime | None = None
|
||||
last_weekly_streak: datetime | None = None
|
||||
playcount: int = 0
|
||||
top_10p_placements: int = 0
|
||||
top_50p_placements: int = 0
|
||||
user_id: int
|
||||
weekly_streak_best: int = 0
|
||||
weekly_streak_current: int = 0
|
||||
|
||||
|
||||
class Page(BaseModel):
|
||||
html: str = ""
|
||||
raw: str = ""
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
# 基本信息
|
||||
id: int
|
||||
username: str
|
||||
avatar_url: str
|
||||
country_code: str
|
||||
default_group: str = "default"
|
||||
is_active: bool = True
|
||||
is_bot: bool = False
|
||||
is_deleted: bool = False
|
||||
is_online: bool = True
|
||||
is_supporter: bool = False
|
||||
is_restricted: bool = False
|
||||
last_visit: datetime | None = None
|
||||
pm_friends_only: bool = False
|
||||
profile_colour: str | None = None
|
||||
|
||||
# 个人资料
|
||||
cover_url: str | None = None
|
||||
discord: str | None = None
|
||||
has_supported: bool = False
|
||||
interests: str | None = None
|
||||
join_date: datetime
|
||||
location: str | None = None
|
||||
max_blocks: int = 100
|
||||
max_friends: int = 500
|
||||
occupation: str | None = None
|
||||
playmode: GameMode = GameMode.OSU
|
||||
playstyle: list[PlayStyle] = []
|
||||
post_count: int = 0
|
||||
profile_hue: int | None = None
|
||||
profile_order: list[str] = [
|
||||
"me",
|
||||
"recent_activity",
|
||||
"top_ranks",
|
||||
"medals",
|
||||
"historical",
|
||||
"beatmaps",
|
||||
"kudosu",
|
||||
]
|
||||
title: str | None = None
|
||||
title_url: str | None = None
|
||||
twitter: str | None = None
|
||||
website: str | None = None
|
||||
session_verified: bool = False
|
||||
support_level: int = 0
|
||||
|
||||
# 关联对象
|
||||
country: Country
|
||||
cover: Cover
|
||||
kudosu: Kudosu
|
||||
statistics: Statistics
|
||||
statistics_rulesets: dict[str, Statistics]
|
||||
|
||||
# 计数信息
|
||||
beatmap_playcounts_count: int = 0
|
||||
comments_count: int = 0
|
||||
favourite_beatmapset_count: int = 0
|
||||
follower_count: int = 0
|
||||
graveyard_beatmapset_count: int = 0
|
||||
guest_beatmapset_count: int = 0
|
||||
loved_beatmapset_count: int = 0
|
||||
mapping_follower_count: int = 0
|
||||
nominated_beatmapset_count: int = 0
|
||||
pending_beatmapset_count: int = 0
|
||||
ranked_beatmapset_count: int = 0
|
||||
ranked_and_approved_beatmapset_count: int = 0
|
||||
unranked_beatmapset_count: int = 0
|
||||
scores_best_count: int = 0
|
||||
scores_first_count: int = 0
|
||||
scores_pinned_count: int = 0
|
||||
scores_recent_count: int = 0
|
||||
|
||||
# 历史数据
|
||||
account_history: list[dict] = []
|
||||
active_tournament_banner: dict | None = None
|
||||
active_tournament_banners: list[dict] = []
|
||||
badges: list[dict] = []
|
||||
current_season_stats: dict | None = None
|
||||
daily_challenge_user_stats: DailyChallengeStats | None = None
|
||||
groups: list[dict] = []
|
||||
monthly_playcounts: list[MonthlyPlaycount] = []
|
||||
page: Page = Page()
|
||||
previous_usernames: list[str] = []
|
||||
rank_highest: RankHighest | None = None
|
||||
rank_history: RankHistory | None = None
|
||||
rankHistory: RankHistory | None = None # 兼容性别名
|
||||
replays_watched_counts: list[dict] = []
|
||||
team: "Team | None" = None
|
||||
user_achievements: list[UserAchievement] = []
|
||||
|
||||
|
||||
class APIUser(BaseModel):
|
||||
id: int
|
||||
|
||||
@@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
room,
|
||||
score,
|
||||
user,
|
||||
)
|
||||
@@ -14,4 +15,9 @@ from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
|
||||
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
|
||||
__all__ = [
|
||||
"api_router",
|
||||
"auth_router",
|
||||
"fetcher_router",
|
||||
"signalr_router",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
|
||||
from app.auth import (
|
||||
@@ -12,17 +12,21 @@ from app.auth import (
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import User as DBUser
|
||||
from app.database import DailyChallengeStats, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
from app.log import logger
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -110,12 +114,12 @@ async def register_user(
|
||||
email_errors = validate_email(user_email)
|
||||
password_errors = validate_password(user_password)
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
|
||||
result = await db.exec(select(User).where(User.username == user_username))
|
||||
existing_user = result.first()
|
||||
if existing_user:
|
||||
username_errors.append("Username is already taken")
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
|
||||
result = await db.exec(select(User).where(User.email == user_email))
|
||||
existing_email = result.first()
|
||||
if existing_email:
|
||||
email_errors.append("Email is already taken")
|
||||
@@ -135,119 +139,41 @@ async def register_user(
|
||||
|
||||
try:
|
||||
# 创建新用户
|
||||
from datetime import datetime
|
||||
import time
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
|
||||
)
|
||||
)
|
||||
next_id = result.one()[0]
|
||||
if next_id <= 2:
|
||||
await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
|
||||
await db.commit()
|
||||
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(), # 安全用户名(小写)
|
||||
new_user = User(
|
||||
username=user_username,
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country="CN", # 默认国家
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0, # 默认模式
|
||||
play_style=0, # 默认游戏风格
|
||||
country_code="CN", # 默认国家
|
||||
join_date=datetime.now(UTC),
|
||||
last_visit=datetime.now(UTC),
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
|
||||
# 保存用户ID,因为会话可能会关闭
|
||||
user_id = new_user.id
|
||||
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
|
||||
await db.commit()
|
||||
|
||||
# 重新创建用户
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(),
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1,
|
||||
country="CN",
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0,
|
||||
play_style=0,
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
user_id = new_user.id
|
||||
|
||||
# 最终检查ID是否有效
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
errors = RegistrationRequestErrors(
|
||||
message=(
|
||||
"Failed to create account with valid ID. "
|
||||
"Please contact support."
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
except Exception as fix_error:
|
||||
await db.rollback()
|
||||
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
|
||||
errors = RegistrationRequestErrors(
|
||||
message="Failed to create account with valid ID. Please try again."
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
# 创建默认的 lazer_profile
|
||||
from app.database.user import LazerUserProfile
|
||||
|
||||
lazer_profile = LazerUserProfile(
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
is_bot=False,
|
||||
is_deleted=False,
|
||||
is_online=True,
|
||||
is_supporter=False,
|
||||
is_restricted=False,
|
||||
session_verified=False,
|
||||
has_supported=False,
|
||||
pm_friends_only=False,
|
||||
default_group="default",
|
||||
join_date=datetime.utcnow(),
|
||||
playmode="osu",
|
||||
support_level=0,
|
||||
max_blocks=50,
|
||||
max_friends=250,
|
||||
post_count=0,
|
||||
)
|
||||
|
||||
db.add(lazer_profile)
|
||||
assert new_user.id is not None, "New user ID should not be None"
|
||||
for i in GameMode:
|
||||
statistics = UserStatistics(mode=i, user_id=new_user.id)
|
||||
db.add(statistics)
|
||||
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
|
||||
db.add(daily_challenge_user_stats)
|
||||
await db.commit()
|
||||
|
||||
# 返回成功响应
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content={"message": "Account created successfully", "user_id": user_id},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
# 打印详细错误信息用于调试
|
||||
print(f"Registration error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.exception(f"Registration error for user {user_username}")
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
@@ -323,6 +249,7 @@ async def oauth_token(
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user.id
|
||||
await store_token(
|
||||
db,
|
||||
user.id,
|
||||
|
||||
@@ -5,12 +5,7 @@ import hashlib
|
||||
import json
|
||||
|
||||
from app.calculator import calculate_beatmap_attribute
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database import Beatmap, BeatmapResp, User
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -27,9 +22,8 @@ from .api_router import router
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -39,7 +33,7 @@ async def lookup_beatmap(
|
||||
id: int | None = Query(default=None, alias="id"),
|
||||
md5: str | None = Query(default=None, alias="checksum"),
|
||||
filename: str | None = Query(default=None, alias="filename"),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -56,19 +50,19 @@ async def lookup_beatmap(
|
||||
if beatmap is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
@@ -80,43 +74,28 @@ class BatchGetResp(BaseModel):
|
||||
@router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp)
|
||||
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
b_ids: list[int] = Query(alias="ids[]", default_factory=list),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.order_by(col(Beatmap.last_updated).desc())
|
||||
.limit(50)
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(col(Beatmap.id).in_(b_ids))
|
||||
.limit(50)
|
||||
)
|
||||
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
|
||||
).all()
|
||||
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
await BeatmapResp.from_db(bm, session=db, user=current_user)
|
||||
for bm in beatmaps
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -126,7 +105,7 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
async def get_beatmap_attributes(
|
||||
beatmap: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
mods: list[str] = Query(default_factory=list),
|
||||
ruleset: GameMode | None = Query(default=None),
|
||||
ruleset_id: int | None = Query(default=None),
|
||||
@@ -153,8 +132,8 @@ async def get_beatmap_attributes(
|
||||
f"beatmap:{beatmap}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
if redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
try:
|
||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
|
||||
@@ -164,7 +143,7 @@ async def get_beatmap_attributes(
|
||||
)
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
redis.set(key, attr.model_dump_json())
|
||||
await redis.set(key, attr.model_dump_json())
|
||||
return attr
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import (
|
||||
Beatmapset,
|
||||
BeatmapsetResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from typing import Literal
|
||||
|
||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -12,27 +10,47 @@ from app.fetcher import Fetcher
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from httpx import HTTPStatusError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def lookup_beatmapset(
|
||||
beatmap_id: int = Query(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmapset_id = (
|
||||
await db.exec(select(Beatmap.beatmapset_id).where(Beatmap.id == beatmap_id))
|
||||
).first()
|
||||
if not beatmapset_id:
|
||||
try:
|
||||
resp = await fetcher.get_beatmap(beatmap_id)
|
||||
await Beatmap.from_resp(db, resp)
|
||||
await db.refresh(current_user)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
beatmapset = (
|
||||
await db.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))
|
||||
).first()
|
||||
if not beatmapset:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
resp = await BeatmapsetResp.from_db(beatmapset, session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmapset = (
|
||||
await db.exec(
|
||||
select(Beatmapset)
|
||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmapset.id == sid)
|
||||
)
|
||||
).first()
|
||||
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
|
||||
if not beatmapset:
|
||||
try:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
@@ -40,5 +58,55 @@ async def get_beatmapset(
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
else:
|
||||
resp = BeatmapsetResp.from_db(beatmapset)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
|
||||
async def download_beatmapset(
|
||||
beatmapset: int,
|
||||
no_video: bool = Query(True, alias="noVideo"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if current_user.country_code == "CN":
|
||||
return RedirectResponse(
|
||||
f"https://txy1.sayobot.cn/beatmaps/download/"
|
||||
f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"])
|
||||
async def favourite_beatmapset(
|
||||
beatmapset: int,
|
||||
action: Literal["favourite", "unfavourite"] = Form(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
FavouriteBeatmapset.user_id == current_user.id,
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
if action == "favourite" and existing_favourite:
|
||||
raise HTTPException(status_code=400, detail="Already favourited")
|
||||
elif action == "unfavourite" and not existing_favourite:
|
||||
raise HTTPException(status_code=400, detail="Not favourited")
|
||||
|
||||
if action == "favourite":
|
||||
favourite = FavouriteBeatmapset(
|
||||
user_id=current_user.id, beatmapset_id=beatmapset
|
||||
)
|
||||
db.add(favourite)
|
||||
else:
|
||||
await db.delete(existing_favourite)
|
||||
await db.commit()
|
||||
|
||||
@@ -1,28 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database import User, UserResp
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/me/", response_model=ApiUser)
|
||||
@router.get("/me/{ruleset}", response_model=UserResp)
|
||||
@router.get("/me/", response_model=UserResp)
|
||||
async def get_user_info_default(
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
ruleset: GameMode | None = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户信息(默认使用osu模式)"""
|
||||
# 默认使用osu模式
|
||||
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||
return api_user
|
||||
return await UserResp.from_db(
|
||||
current_user,
|
||||
session,
|
||||
ALL_INCLUDED,
|
||||
ruleset,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import joinedload
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -26,17 +26,19 @@ async def get_relationship(
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationships = await db.exec(
|
||||
select(Relationship)
|
||||
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
Relationship.type == relationship_type,
|
||||
)
|
||||
)
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
|
||||
|
||||
|
||||
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
|
||||
class AddFriendResp(BaseModel):
|
||||
user_relation: RelationshipResp
|
||||
|
||||
|
||||
@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
|
||||
@router.post("/blocks", tags=["relationship"])
|
||||
async def add_relationship(
|
||||
request: Request,
|
||||
@@ -87,14 +89,10 @@ async def add_relationship(
|
||||
if origin_type == RelationshipType.FOLLOW:
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship)
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user_id,
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
.options(
|
||||
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
|
||||
@@ -1,106 +1,296 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.room import RoomIndex
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
|
||||
from app.database.playlists import Playlist, PlaylistResp
|
||||
from app.database.room import Room, RoomBase, RoomResp
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room
|
||||
from app.models.multiplayer_hub import (
|
||||
MultiplayerRoom,
|
||||
MultiplayerRoomUser,
|
||||
ServerMultiplayerRoom,
|
||||
)
|
||||
from app.models.room import RoomStatus
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from api_router import router
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from sqlmodel import select
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/rooms", tags=["rooms"], response_model=list[Room])
|
||||
@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp])
|
||||
async def get_all_rooms(
|
||||
mode: str | None = Query(None), # TODO: 对房间根据状态进行筛选
|
||||
status: str | None = Query(None),
|
||||
category: str | None = Query(
|
||||
None
|
||||
), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗)
|
||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||
default="open"
|
||||
), # TODO: 对房间根据状态进行筛选
|
||||
category: str = Query(default="realtime"), # TODO
|
||||
status: RoomStatus | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
all_roomID = (await db.exec(select(RoomIndex))).all()
|
||||
redis = get_redis()
|
||||
if redis is not None:
|
||||
resp: list[Room] = []
|
||||
for id in all_roomID:
|
||||
dumped_room = redis.get(str(id))
|
||||
validated_room = MultiplayerRoom.model_validate_json(str(dumped_room))
|
||||
flag: bool = False
|
||||
if status is not None:
|
||||
if (
|
||||
validated_room.State == MultiplayerRoomState.OPEN
|
||||
and status == "idle"
|
||||
):
|
||||
flag = True
|
||||
elif validated_room != MultiplayerRoomState.CLOSED:
|
||||
flag = True
|
||||
if flag:
|
||||
resp.append(
|
||||
await Room.from_mpRoom(
|
||||
MultiplayerRoom.model_validate_json(str(dumped_room)),
|
||||
db,
|
||||
fetcher,
|
||||
)
|
||||
)
|
||||
return resp
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Redis Error")
|
||||
rooms = MultiplayerHubs.rooms.values()
|
||||
resp_list: list[RoomResp] = []
|
||||
for room in rooms:
|
||||
# if category == "realtime" and room.category != "normal":
|
||||
# continue
|
||||
# elif category != room.category and category != "":
|
||||
# continue
|
||||
resp_list.append(await RoomResp.from_hub(room))
|
||||
return resp_list
|
||||
|
||||
|
||||
@router.get("/rooms/{room}", tags=["room"], response_model=Room)
|
||||
class APICreatedRoom(RoomResp):
|
||||
error: str = ""
|
||||
|
||||
|
||||
class APIUploadedRoom(RoomBase):
|
||||
def to_room(self) -> Room:
|
||||
"""
|
||||
将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
|
||||
"""
|
||||
room_dict = self.model_dump()
|
||||
room_dict.pop("playlist", None)
|
||||
# host_id 已在字段中
|
||||
return Room(**room_dict)
|
||||
|
||||
id: int | None
|
||||
host_id: int | None = None
|
||||
playlist: list[Playlist]
|
||||
|
||||
|
||||
@router.post("/rooms", tags=["room"], response_model=APICreatedRoom)
|
||||
async def create_room(
|
||||
room: APIUploadedRoom,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
# db_room = Room.from_resp(room)
|
||||
await db.refresh(current_user)
|
||||
user_id = current_user.id
|
||||
db_room = room.to_room()
|
||||
db_room.host_id = current_user.id if current_user.id else 1
|
||||
db.add(db_room)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
|
||||
playlist: list[Playlist] = []
|
||||
# 处理 APIUploadedRoom 里的 playlist 字段
|
||||
for item in room.playlist:
|
||||
# 确保 room_id 正确赋值
|
||||
item.id = await Playlist.get_next_id_for_room(db_room.id, db)
|
||||
item.room_id = db_room.id
|
||||
item.owner_id = user_id if user_id else 1
|
||||
db.add(item)
|
||||
await db.commit()
|
||||
await db.refresh(item)
|
||||
playlist.append(item)
|
||||
await db.refresh(db_room)
|
||||
db_room.playlist = playlist
|
||||
server_room = ServerMultiplayerRoom(
|
||||
room=MultiplayerRoom.from_db(db_room),
|
||||
category=db_room.category,
|
||||
start_at=datetime.now(UTC),
|
||||
hub=MultiplayerHubs,
|
||||
)
|
||||
MultiplayerHubs.rooms[db_room.id] = server_room
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room))
|
||||
created_room.error = ""
|
||||
return created_room
|
||||
|
||||
|
||||
@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp)
|
||||
async def get_room(
|
||||
room: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
dumped_room = str(redis.get(str(room)))
|
||||
if dumped_room is not None:
|
||||
resp = await Room.from_mpRoom(
|
||||
MultiplayerRoom.model_validate_json(str(dumped_room)), db, fetcher
|
||||
)
|
||||
return resp
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Room Not Found")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Redis error")
|
||||
|
||||
|
||||
class APICreatedRoom(Room):
|
||||
error: str | None
|
||||
|
||||
|
||||
@router.post("/rooms", tags=["beatmap"], response_model=APICreatedRoom)
|
||||
async def create_room(
|
||||
room: Room,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
room_index = RoomIndex()
|
||||
db.add(room_index)
|
||||
await db.commit()
|
||||
await db.refresh(room_index)
|
||||
server_room = await MultiplayerRoom.from_apiRoom(room, db, fetcher)
|
||||
redis.set(str(room_index.id), server_room.model_dump_json())
|
||||
room.room_id = room_index.id
|
||||
return APICreatedRoom(**room.model_dump(), error=None)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="redis error")
|
||||
server_room = MultiplayerHubs.rooms[room]
|
||||
return await RoomResp.from_hub(server_room)
|
||||
|
||||
|
||||
@router.delete("/rooms/{room}", tags=["room"])
|
||||
async def remove_room(room: int, db: AsyncSession = Depends(get_db)):
|
||||
redis = get_redis()
|
||||
if redis:
|
||||
redis.delete(str(room))
|
||||
room_index = await db.get(RoomIndex, room)
|
||||
if room_index:
|
||||
await db.delete(room_index)
|
||||
async def delete_room(room: int, db: AsyncSession = Depends(get_db)):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
else:
|
||||
await db.delete(db_room)
|
||||
return None
|
||||
|
||||
|
||||
@router.put("/rooms/{room}/users/{user}", tags=["room"])
|
||||
async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)):
|
||||
server_room = MultiplayerHubs.rooms[room]
|
||||
server_room.room.users.append(MultiplayerRoomUser(user_id=user))
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is not None:
|
||||
db_room.participant_count += 1
|
||||
await db.commit()
|
||||
resp = await RoomResp.from_hub(server_room)
|
||||
await db.refresh(db_room)
|
||||
for item in db_room.playlist:
|
||||
resp.playlist.append(await PlaylistResp.from_db(item, ["beatmap"]))
|
||||
return resp
|
||||
else:
|
||||
raise HTTPException(404, "room not found0")
|
||||
|
||||
|
||||
class APILeaderboard(BaseModel):
|
||||
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
|
||||
user_score: ItemAttemptsResp | None = None
|
||||
|
||||
|
||||
@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard)
|
||||
async def get_room_leaderboard(
|
||||
room: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
server_room = MultiplayerHubs.rooms[room]
|
||||
if not server_room:
|
||||
raise HTTPException(404, "Room not found")
|
||||
|
||||
aggs = await db.exec(
|
||||
select(ItemAttemptsCount)
|
||||
.where(ItemAttemptsCount.room_id == room)
|
||||
.order_by(col(ItemAttemptsCount.total_score).desc())
|
||||
)
|
||||
aggs_resp = []
|
||||
user_agg = None
|
||||
for i, agg in enumerate(aggs):
|
||||
resp = await ItemAttemptsResp.from_db(agg, db)
|
||||
resp.position = i + 1
|
||||
aggs_resp.append(resp)
|
||||
if agg.user_id == current_user.id:
|
||||
user_agg = resp
|
||||
return APILeaderboard(
|
||||
leaderboard=aggs_resp,
|
||||
user_score=user_agg,
|
||||
)
|
||||
|
||||
|
||||
class RoomEvents(BaseModel):
|
||||
beatmaps: list[BeatmapResp] = Field(default_factory=list)
|
||||
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
|
||||
current_playlist_item_id: int = 0
|
||||
events: list[MultiplayerEventResp] = Field(default_factory=list)
|
||||
first_event_id: int = 0
|
||||
last_event_id: int = 0
|
||||
playlist_items: list[PlaylistResp] = Field(default_factory=list)
|
||||
room: RoomResp
|
||||
user: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"])
|
||||
async def get_room_events(
|
||||
room_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
after: int | None = Query(None, ge=0),
|
||||
before: int | None = Query(None, ge=0),
|
||||
):
|
||||
events = (
|
||||
await db.exec(
|
||||
select(MultiplayerEvent)
|
||||
.where(
|
||||
MultiplayerEvent.room_id == room_id,
|
||||
col(MultiplayerEvent.id) > after if after is not None else True,
|
||||
col(MultiplayerEvent.id) < before if before is not None else True,
|
||||
)
|
||||
.order_by(col(MultiplayerEvent.id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_ids = set()
|
||||
playlist_items = {}
|
||||
beatmap_ids = set()
|
||||
|
||||
event_resps = []
|
||||
first_event_id = 0
|
||||
last_event_id = 0
|
||||
|
||||
current_playlist_item_id = 0
|
||||
for event in events:
|
||||
event_resps.append(MultiplayerEventResp.from_db(event))
|
||||
|
||||
if event.user_id:
|
||||
user_ids.add(event.user_id)
|
||||
|
||||
if event.playlist_item_id is not None and (
|
||||
playitem := (
|
||||
await db.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == event.playlist_item_id,
|
||||
Playlist.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
):
|
||||
current_playlist_item_id = playitem.id
|
||||
playlist_items[event.playlist_item_id] = playitem
|
||||
beatmap_ids.add(playitem.beatmap_id)
|
||||
scores = await db.exec(
|
||||
select(Score).where(
|
||||
Score.playlist_item_id == event.playlist_item_id,
|
||||
Score.room_id == room_id,
|
||||
)
|
||||
)
|
||||
for score in scores:
|
||||
user_ids.add(score.user_id)
|
||||
beatmap_ids.add(score.beatmap_id)
|
||||
|
||||
assert event.id is not None
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
if room := MultiplayerHubs.rooms.get(room_id):
|
||||
current_playlist_item_id = room.queue.current_item.id
|
||||
room_resp = await RoomResp.from_hub(room)
|
||||
else:
|
||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
room_resp = await RoomResp.from_db(room)
|
||||
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [
|
||||
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
|
||||
]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
|
||||
playlist_items_resps = [
|
||||
await PlaylistResp.from_db(item) for item in playlist_items.values()
|
||||
]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
beatmapsets=beatmapset_resps,
|
||||
current_playlist_item_id=current_playlist_item_id,
|
||||
events=event_resps,
|
||||
first_event_id=first_event_id,
|
||||
last_event_id=last_event_id,
|
||||
playlist_items=playlist_items_resps,
|
||||
room=room_resp,
|
||||
user=user_resps,
|
||||
)
|
||||
|
||||
@@ -1,18 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
import time
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
Beatmap,
|
||||
Playlist,
|
||||
Room,
|
||||
Score,
|
||||
ScoreResp,
|
||||
ScoreToken,
|
||||
ScoreTokenResp,
|
||||
User,
|
||||
)
|
||||
from app.database.playlist_attempts import ItemAttemptsCount
|
||||
from app.database.playlist_best_score import (
|
||||
PlaylistBestScore,
|
||||
get_position,
|
||||
process_playlist_best_score,
|
||||
)
|
||||
from app.database.score import (
|
||||
MultiplayerScores,
|
||||
ScoreAround,
|
||||
get_leaderboard,
|
||||
process_score,
|
||||
process_user,
|
||||
)
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.score import Score, ScoreResp, process_score, process_user
|
||||
from app.database.score_token import ScoreToken, ScoreTokenResp
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
LeaderboardType,
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
@@ -21,11 +44,75 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select, true
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
READ_SCORE_TIMEOUT = 10
|
||||
|
||||
|
||||
async def submit_score(
|
||||
info: SoloScoreSubmissionInfo,
|
||||
beatmap: int,
|
||||
token: int,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token)
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
beatmap_status = (
|
||||
await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap))
|
||||
).first()
|
||||
if beatmap_status is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
ranked = beatmap_status in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
score = await process_score(
|
||||
current_user,
|
||||
beatmap,
|
||||
ranked,
|
||||
score_token,
|
||||
info,
|
||||
fetcher,
|
||||
db,
|
||||
redis,
|
||||
)
|
||||
await db.refresh(current_user)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
@@ -37,44 +124,26 @@ class BeatmapScores(BaseModel):
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
mode: GameMode,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: GameMode | None = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]"),
|
||||
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == current_user.id,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
all_scores, user_score = await get_leaderboard(
|
||||
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
|
||||
)
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score, user_score.user)
|
||||
if user_score
|
||||
else None,
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,7 +163,7 @@ async def get_user_beatmap_score(
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
@@ -103,7 +172,7 @@ async def get_user_beatmap_score(
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
Score.select_clause(True)
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
@@ -118,9 +187,10 @@ async def get_user_beatmap_score(
|
||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||
)
|
||||
else:
|
||||
resp = await ScoreResp.from_db(db, user_score)
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=await ScoreResp.from_db(db, user_score, user_score.user),
|
||||
position=resp.rank_global or 0,
|
||||
score=resp,
|
||||
)
|
||||
|
||||
|
||||
@@ -134,7 +204,7 @@ async def get_user_all_beatmap_scores(
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
@@ -143,7 +213,7 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
Score.select_clause()
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
@@ -153,9 +223,7 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
).all()
|
||||
|
||||
return [
|
||||
await ScoreResp.from_db(db, score, current_user) for score in all_user_scores
|
||||
]
|
||||
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -166,9 +234,10 @@ async def create_solo_score(
|
||||
version_hash: str = Form(""),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
assert current_user.id
|
||||
async with db:
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
@@ -190,64 +259,275 @@ async def submit_solo_score(
|
||||
beatmap: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
async with db:
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
|
||||
return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
beatmap_id: int = Form(),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
version_hash: str = Form(""),
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
if room.ended_at and room.ended_at < datetime.now(UTC):
|
||||
raise HTTPException(status_code=400, detail="Room has ended")
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||
|
||||
# validate
|
||||
if not item.freestyle:
|
||||
if item.ruleset_id != ruleset_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||
)
|
||||
if item.beatmap_id != beatmap_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||
)
|
||||
agg = await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room_id,
|
||||
ItemAttemptsCount.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
agg = agg.first()
|
||||
if agg and room.max_attempts and agg.attempts >= room.max_attempts:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="You have reached the maximum attempts for this room",
|
||||
)
|
||||
if item.expired:
|
||||
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||
if item.played_at:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Playlist item has already been played"
|
||||
)
|
||||
# 这里应该不用验证mod了吧。。。
|
||||
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap_id,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
playlist_item_id=playlist_id,
|
||||
)
|
||||
session.add(score_token)
|
||||
await session.commit()
|
||||
await session.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
|
||||
async def submit_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||
|
||||
user_id = current_user.id
|
||||
score_resp = await submit_score(
|
||||
info,
|
||||
item.beatmap_id,
|
||||
token,
|
||||
current_user,
|
||||
session,
|
||||
redis,
|
||||
fetcher,
|
||||
item.id,
|
||||
room_id,
|
||||
)
|
||||
await process_playlist_best_score(
|
||||
room_id,
|
||||
playlist_id,
|
||||
user_id,
|
||||
score_resp.id,
|
||||
score_resp.total_score,
|
||||
session,
|
||||
redis,
|
||||
)
|
||||
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
|
||||
return score_resp
|
||||
|
||||
|
||||
class IndexedScoreResp(MultiplayerScores):
|
||||
total: int
|
||||
user_score: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp
|
||||
)
|
||||
async def index_playlist_scores(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = 50,
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
limit = clamp(limit, 1, 50)
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore)
|
||||
.where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.total_score < cursor,
|
||||
)
|
||||
.order_by(col(PlaylistBestScore.total_score).desc())
|
||||
.limit(limit + 1)
|
||||
)
|
||||
).all()
|
||||
has_more = len(scores) > limit
|
||||
if has_more:
|
||||
scores = scores[:-1]
|
||||
|
||||
user_score = None
|
||||
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
|
||||
for score in score_resp:
|
||||
score.position = await get_position(room_id, playlist_id, score.id, session)
|
||||
if score.user_id == current_user.id:
|
||||
user_score = score
|
||||
resp = IndexedScoreResp(
|
||||
scores=score_resp,
|
||||
user_score=user_score,
|
||||
total=len(scores),
|
||||
params={
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
if has_more:
|
||||
resp.cursor = {
|
||||
"total_score": scores[-1].total_score,
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def show_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
start_time = time.time()
|
||||
score_record = None
|
||||
completed = False
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
if score_record is None:
|
||||
score_record = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.score_id == score_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
beatmap_status = (
|
||||
await db.exec(
|
||||
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
|
||||
if completed_players := await redis.get(
|
||||
f"multiplayer:{room_id}:gameplay:players"
|
||||
):
|
||||
completed = completed_players == "0"
|
||||
if score_record and completed:
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(room_id, playlist_id, score_id, session)
|
||||
if completed:
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
).first()
|
||||
if beatmap_status is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
ranked = beatmap_status in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
score = await process_score(
|
||||
current_user,
|
||||
beatmap,
|
||||
ranked,
|
||||
score_token,
|
||||
info,
|
||||
fetcher,
|
||||
db,
|
||||
redis,
|
||||
)
|
||||
await db.refresh(current_user)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (
|
||||
await db.exec(Score.select_clause().where(Score.id == score_id))
|
||||
).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score, current_user)
|
||||
).all()
|
||||
higher_scores = []
|
||||
lower_scores = []
|
||||
for score in scores:
|
||||
if score.total_score > resp.total_score:
|
||||
higher_scores.append(await ScoreResp.from_db(session, score.score))
|
||||
elif score.total_score < resp.total_score:
|
||||
lower_scores.append(await ScoreResp.from_db(session, score.score))
|
||||
resp.scores_around = ScoreAround(
|
||||
higher=MultiplayerScores(scores=higher_scores),
|
||||
lower=MultiplayerScores(scores=lower_scores),
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def get_user_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
score_record = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
score_record = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if score_record:
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(
|
||||
room_id, playlist_id, score_record.score_id, session
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.database import User, UserResp
|
||||
from app.database.lazer_user import SEARCH_INCLUDED
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import INT_TO_MODE
|
||||
from app.models.user import User as ApiUser
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .api_router import router
|
||||
|
||||
@@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
|
||||
# ---------- Shared Utility ----------
|
||||
async def get_user_by_lookup(
|
||||
db: AsyncSession, lookup: str, key: str = "id"
|
||||
) -> DBUser | None:
|
||||
"""根据查找方式获取用户"""
|
||||
if key == "id":
|
||||
try:
|
||||
user_id = int(lookup)
|
||||
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
|
||||
return result.first()
|
||||
except ValueError:
|
||||
return None
|
||||
elif key == "username":
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
|
||||
return result.first()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# ---------- Batch Users ----------
|
||||
class BatchUserResponse(BaseModel):
|
||||
users: list[ApiUser]
|
||||
users: list[UserResp]
|
||||
|
||||
|
||||
@router.get("/users", response_model=BatchUserResponse)
|
||||
@@ -51,75 +28,44 @@ async def get_users(
|
||||
):
|
||||
if user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
|
||||
)
|
||||
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
|
||||
).all()
|
||||
else:
|
||||
searched_users = (
|
||||
await session.exec(DBUser.all_select_clause().limit(50))
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).limit(50))).all()
|
||||
return BatchUserResponse(
|
||||
users=[
|
||||
await convert_db_user_to_api_user(
|
||||
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
|
||||
await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
for searched_user in searched_users
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# # ---------- Individual User ----------
|
||||
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
|
||||
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
|
||||
# async def get_user_with_mode(
|
||||
# user_lookup: str,
|
||||
# mode: Literal["osu", "taiko", "fruits", "mania"],
|
||||
# key: Literal["id", "username"] = Query("id"),
|
||||
# current_user: DBUser = Depends(get_current_user),
|
||||
# db: AsyncSession = Depends(get_db),
|
||||
# ):
|
||||
# """获取指定游戏模式的用户信息"""
|
||||
# user = await get_user_by_lookup(db, user_lookup, key)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# return await convert_db_user_to_api_user(user, mode)
|
||||
|
||||
|
||||
# @router.get("/users/{user_lookup}", response_model=ApiUser)
|
||||
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
|
||||
# async def get_user_default(
|
||||
# user_lookup: str,
|
||||
# key: Literal["id", "username"] = Query("id"),
|
||||
# current_user: DBUser = Depends(get_current_user),
|
||||
# db: AsyncSession = Depends(get_db),
|
||||
# ):
|
||||
# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)"""
|
||||
# user = await get_user_by_lookup(db, user_lookup, key)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# return await convert_db_user_to_api_user(user, "osu")
|
||||
|
||||
|
||||
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/users/{user}/", response_model=ApiUser)
|
||||
@router.get("/users/{user}", response_model=ApiUser)
|
||||
@router.get("/users/{user}/{ruleset}", response_model=UserResp)
|
||||
@router.get("/users/{user}/", response_model=UserResp)
|
||||
@router.get("/users/{user}", response_model=UserResp)
|
||||
async def get_user_info(
|
||||
user: str,
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
ruleset: GameMode | None = None,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().where(
|
||||
DBUser.id == int(user)
|
||||
select(User).where(
|
||||
User.id == int(user)
|
||||
if user.isdigit()
|
||||
else DBUser.name == user.removeprefix("@")
|
||||
else User.username == user.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not searched_user:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
|
||||
return await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
ruleset=ruleset,
|
||||
)
|
||||
|
||||
@@ -6,9 +6,9 @@ import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
@@ -21,7 +21,6 @@ from app.signalr.store import ResultStore
|
||||
from app.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
@@ -49,7 +48,7 @@ class Client:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self.procotol = protocol
|
||||
self.protocol = protocol
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
@@ -62,14 +61,14 @@ class Client:
|
||||
return int(self.connection_id)
|
||||
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
await self.connection.send_bytes(self.protocol.encode(packet))
|
||||
|
||||
async def receive_packets(self) -> list[Packet]:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return []
|
||||
return self.procotol.decode(d)
|
||||
return self.protocol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
@@ -263,10 +262,9 @@ class Hub[TState: UserState]:
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
call_params.append(
|
||||
client.protocol.validate_object(args.pop(0), param.annotation)
|
||||
)
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
|
||||
@@ -2,15 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from datetime import UTC, datetime
|
||||
from typing import override
|
||||
|
||||
from app.database.relationship import Relationship, RelationshipType
|
||||
from app.dependencies.database import engine
|
||||
from app.database import Relationship, RelationshipType, User
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -30,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
) -> set[Coroutine]:
|
||||
if store is not None and not store.pushable:
|
||||
return set()
|
||||
data = store.to_dict() if store else None
|
||||
data = store.for_push if store else None
|
||||
return {
|
||||
self.broadcast_group_call(
|
||||
self.online_presence_watchers_group(),
|
||||
@@ -54,6 +54,18 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
async def _clean_state(self, state: MetadataClientState) -> None:
|
||||
if state.pushable:
|
||||
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
|
||||
redis = get_redis()
|
||||
if await redis.exists(f"metadata:online:{state.connection_id}"):
|
||||
await redis.delete(f"metadata:online:{state.connection_id}")
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
user = (
|
||||
await session.exec(
|
||||
select(User).where(User.id == int(state.connection_id))
|
||||
)
|
||||
).one()
|
||||
user.last_visit = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> MetadataClientState:
|
||||
@@ -89,10 +101,14 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
"FriendPresenceUpdated",
|
||||
friend_id,
|
||||
friend_state.to_dict(),
|
||||
friend_state.for_push
|
||||
if friend_state.pushable
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
redis = get_redis()
|
||||
await redis.set(f"metadata:online:{user_id}", "")
|
||||
|
||||
async def UpdateStatus(self, client: Client, status: int) -> None:
|
||||
status_ = OnlineStatus(status)
|
||||
@@ -107,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
store.for_push,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
||||
async def UpdateActivity(
|
||||
self, client: Client, activity: UserActivity | None
|
||||
) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
activity = (
|
||||
TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
if activity_dict
|
||||
else None
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.user_activity = activity
|
||||
store.activity = activity
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
store.for_push,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -139,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
store,
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.pushable
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,15 +7,13 @@ import struct
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from app.database import Beatmap
|
||||
from app.database import Beatmap, User
|
||||
from app.database.score import Score
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import engine
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
||||
from app.models.signalr import serialize_to_list
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
|
||||
from app.models.spectator_hub import (
|
||||
APIUser,
|
||||
FrameDataBundle,
|
||||
@@ -70,8 +68,8 @@ def save_replay(
|
||||
md5: str,
|
||||
username: str,
|
||||
score: Score,
|
||||
statistics: ScoreStatisticsInt,
|
||||
maximum_statistics: ScoreStatisticsInt,
|
||||
statistics: ScoreStatistics,
|
||||
maximum_statistics: ScoreStatistics,
|
||||
frames: list[LegacyReplayFrame],
|
||||
) -> None:
|
||||
data = bytearray()
|
||||
@@ -108,8 +106,8 @@ def save_replay(
|
||||
last_time = 0
|
||||
for frame in frames:
|
||||
frame_strs.append(
|
||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state}"
|
||||
f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
|
||||
f"|{frame.mouse_y or 0.0}|{frame.button_state}"
|
||||
)
|
||||
last_time = frame.time
|
||||
frame_strs.append("-12345|0|0|0")
|
||||
@@ -166,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
tasks = [
|
||||
self.call_noblock(
|
||||
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
||||
)
|
||||
self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
|
||||
for user_id, store in self.state.items()
|
||||
if store.state is not None
|
||||
]
|
||||
@@ -197,7 +193,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
).first()
|
||||
if not user:
|
||||
return
|
||||
name = user.name
|
||||
name = user.username
|
||||
store.state = state
|
||||
store.beatmap_status = beatmap.beatmap_status
|
||||
store.checksum = beatmap.checksum
|
||||
@@ -215,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state),
|
||||
state,
|
||||
)
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
@@ -223,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
state = self.get_or_create_state(client)
|
||||
if not state.score:
|
||||
return
|
||||
state.score.score_info.acc = frame_data.header.acc
|
||||
state.score.score_info.accuracy = frame_data.header.accuracy
|
||||
state.score.score_info.combo = frame_data.header.combo
|
||||
state.score.score_info.max_combo = frame_data.header.max_combo
|
||||
state.score.score_info.statistics = frame_data.header.statistics
|
||||
@@ -234,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self.group_id(user_id),
|
||||
"UserSentFrames",
|
||||
user_id,
|
||||
frame_data.model_dump(),
|
||||
frame_data,
|
||||
)
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
@@ -297,19 +293,18 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
score_record.id,
|
||||
)
|
||||
# save replay
|
||||
if store.state.state == SpectatedUserState.Passed:
|
||||
score_record.has_replay = True
|
||||
await session.commit()
|
||||
await session.refresh(score_record)
|
||||
save_replay(
|
||||
ruleset_id=store.ruleset_id,
|
||||
md5=store.checksum,
|
||||
username=store.score.score_info.user.name,
|
||||
score=score_record,
|
||||
statistics=store.score.score_info.statistics,
|
||||
maximum_statistics=store.score.score_info.maximum_statistics,
|
||||
frames=store.score.replay_frames,
|
||||
)
|
||||
score_record.has_replay = True
|
||||
await session.commit()
|
||||
await session.refresh(score_record)
|
||||
save_replay(
|
||||
ruleset_id=store.ruleset_id,
|
||||
md5=store.checksum,
|
||||
username=store.score.score_info.user.name,
|
||||
score=score_record,
|
||||
statistics=store.score.score_info.statistics,
|
||||
maximum_statistics=store.score.score_info.maximum_statistics,
|
||||
frames=store.score.replay_frames,
|
||||
)
|
||||
|
||||
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
|
||||
if state.state == SpectatedUserState.Playing:
|
||||
@@ -318,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self.group_id(user_id),
|
||||
"UserFinishedPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state) if state else None,
|
||||
state,
|
||||
)
|
||||
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
@@ -329,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
serialize_to_list(target_store.state),
|
||||
target_store.state,
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.watched_user.add(target_id)
|
||||
@@ -339,7 +334,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
username = (
|
||||
await session.exec(select(User.name).where(User.id == user_id))
|
||||
await session.exec(select(User.username).where(User.id == user_id))
|
||||
).first()
|
||||
if not username:
|
||||
return
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
import datetime
|
||||
from enum import Enum, IntEnum
|
||||
import inspect
|
||||
import json
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
||||
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
|
||||
|
||||
import msgpack_lazer_api as m
|
||||
from pydantic import BaseModel
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
@@ -73,8 +83,67 @@ class Protocol(TypingProtocol):
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes: ...
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type) -> Any: ...
|
||||
|
||||
|
||||
class MsgpackProtocol:
|
||||
@classmethod
|
||||
def serialize_msgpack(cls, v: Any) -> Any:
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_to_list(v)
|
||||
elif issubclass(typ, list):
|
||||
return [cls.serialize_msgpack(item) for item in v]
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
return [v, 0]
|
||||
elif issubclass(typ, datetime.timedelta):
|
||||
return int(v.total_seconds() * 10_000_000)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
||||
for k, value in v.items()
|
||||
}
|
||||
elif issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v) if v in list_ else v.value
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
|
||||
values = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
values.append(cls.serialize_msgpack(v=getattr(value, field)))
|
||||
if issubclass(value.__class__, SignalRUnionMessage):
|
||||
return [value.__class__.union_type, values]
|
||||
else:
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
||||
if isinstance(v, list):
|
||||
d = {}
|
||||
i = 0
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
anno = info.annotation
|
||||
if anno is None:
|
||||
d[camel_to_snake(field)] = v[i]
|
||||
else:
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
i += 1
|
||||
return d
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
@@ -140,6 +209,53 @@ class MsgpackProtocol:
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type) -> Any:
|
||||
if issubclass(typ, BaseModel):
|
||||
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return v[0]
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
return datetime.timedelta(seconds=int(v / 10_000_000))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||
elif get_origin(typ) is dict:
|
||||
return {
|
||||
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
||||
v, get_args(typ)[1]
|
||||
)
|
||||
for k, v in v.items()
|
||||
}
|
||||
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||
args = get_args(typ)
|
||||
if len(args) == 2 and NoneType in args:
|
||||
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||
if len(non_none_args) == 1:
|
||||
if v is None:
|
||||
return None
|
||||
return cls.validate_object(v, non_none_args[0])
|
||||
|
||||
# suppose use `MessagePack-CSharp Union | None`
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(
|
||||
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot validate {v} to {typ}, "
|
||||
"only SignalRUnionMessage subclasses are supported"
|
||||
)
|
||||
union_type = v[0]
|
||||
for arg in args:
|
||||
assert issubclass(arg, SignalRUnionMessage)
|
||||
if arg.union_type == union_type:
|
||||
return cls.validate_object(v[1], arg)
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload = [packet.type.value, packet.header or {}]
|
||||
@@ -151,20 +267,24 @@ class MsgpackProtocol:
|
||||
]
|
||||
)
|
||||
if packet.arguments is not None:
|
||||
payload.append(packet.arguments)
|
||||
payload.append(
|
||||
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
|
||||
)
|
||||
if packet.stream_ids is not None:
|
||||
payload.append(packet.stream_ids)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
result_kind = 2
|
||||
if packet.error:
|
||||
result_kind = 1
|
||||
elif packet.result is None:
|
||||
elif packet.result is not None:
|
||||
result_kind = 3
|
||||
payload.extend(
|
||||
[
|
||||
packet.invocation_id,
|
||||
result_kind,
|
||||
packet.error or packet.result or None,
|
||||
packet.error
|
||||
or MsgpackProtocol.serialize_msgpack(packet.result)
|
||||
or None,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, ClosePacket):
|
||||
@@ -181,6 +301,86 @@ class MsgpackProtocol:
|
||||
|
||||
|
||||
class JSONProtocol:
|
||||
@classmethod
|
||||
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_model(v, in_union)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
|
||||
for k, value in v.items()
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
return [cls.serialize_to_json(item) for item in v]
|
||||
elif isinstance(v, datetime.datetime):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
total_seconds = int(v.total_seconds())
|
||||
hours, remainder = divmod(total_seconds, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
||||
elif isinstance(v, Enum) and dict_key:
|
||||
return v.value
|
||||
elif isinstance(v, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
|
||||
d = {}
|
||||
is_union = issubclass(v.__class__, SignalRUnionMessage)
|
||||
for field, info in v.__class__.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
name = (
|
||||
snake_to_camel(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
if not is_union
|
||||
else snake_to_pascal(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
)
|
||||
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
|
||||
if is_union and not in_union:
|
||||
return {
|
||||
"$dtype": v.__class__.__name__,
|
||||
"$value": d,
|
||||
}
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def process_object(
|
||||
v: Any, typ: type[BaseModel], from_union: bool = False
|
||||
) -> dict[str, Any]:
|
||||
d = {}
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
name = (
|
||||
snake_to_camel(field, metadata.use_abbr if metadata else True)
|
||||
if not from_union
|
||||
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
|
||||
)
|
||||
value = v.get(name)
|
||||
anno = typ.model_fields[field].annotation
|
||||
if anno is None:
|
||||
d[field] = value
|
||||
continue
|
||||
d[field] = JSONProtocol.validate_object(value, anno)
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||
@@ -225,6 +425,63 @@ class JSONProtocol:
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
|
||||
if issubclass(typ, BaseModel):
|
||||
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
parts = v.split(":")
|
||||
if len(parts) == 3:
|
||||
return datetime.timedelta(
|
||||
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
|
||||
)
|
||||
elif len(parts) == 2:
|
||||
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
|
||||
elif len(parts) == 1:
|
||||
return datetime.timedelta(seconds=int(parts[0]))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||
elif get_origin(typ) is dict:
|
||||
return {
|
||||
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
||||
v, get_args(typ)[1]
|
||||
)
|
||||
for k, v in v.items()
|
||||
}
|
||||
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||
args = get_args(typ)
|
||||
if len(args) == 2 and NoneType in args:
|
||||
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||
if len(non_none_args) == 1:
|
||||
if v is None:
|
||||
return None
|
||||
return cls.validate_object(v, non_none_args[0])
|
||||
|
||||
# suppose use `MessagePack-CSharp Union | None`
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(
|
||||
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot validate {v} to {typ}, "
|
||||
"only SignalRUnionMessage subclasses are supported"
|
||||
)
|
||||
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
|
||||
union_type = v["$dtype"]
|
||||
for arg in args:
|
||||
assert issubclass(arg, SignalRUnionMessage)
|
||||
if arg.__name__ == union_type:
|
||||
return cls.validate_object(v["$value"], arg, True)
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload: dict[str, Any] = {
|
||||
@@ -241,7 +498,9 @@ class JSONProtocol:
|
||||
if packet.invocation_id is not None:
|
||||
payload["invocationId"] = packet.invocation_id
|
||||
if packet.arguments is not None:
|
||||
payload["arguments"] = packet.arguments
|
||||
payload["arguments"] = [
|
||||
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
|
||||
]
|
||||
if packet.stream_ids is not None:
|
||||
payload["streamIds"] = packet.stream_ids
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
@@ -253,7 +512,7 @@ class JSONProtocol:
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
if packet.result is not None:
|
||||
payload["result"] = packet.result
|
||||
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
elif isinstance(packet, ClosePacket):
|
||||
|
||||
535
app/utils.py
535
app/utils.py
@@ -1,465 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.database import (
|
||||
LazerUserCounts,
|
||||
LazerUserProfile,
|
||||
LazerUserStatistics,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.models.user import (
|
||||
Country,
|
||||
Cover,
|
||||
DailyChallengeStats,
|
||||
GradeCounts,
|
||||
Kudosu,
|
||||
Level,
|
||||
Page,
|
||||
RankHighest,
|
||||
RankHistory,
|
||||
Statistics,
|
||||
User,
|
||||
UserAchievement,
|
||||
)
|
||||
|
||||
|
||||
def unix_timestamp_to_windows(timestamp: int) -> int:
|
||||
"""Convert a Unix timestamp to a Windows timestamp."""
|
||||
return (timestamp + 62135596800) * 10_000_000
|
||||
|
||||
|
||||
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
|
||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||
|
||||
# 从db_user获取基本字段值
|
||||
user_id = getattr(db_user, "id")
|
||||
user_name = getattr(db_user, "name")
|
||||
user_country = getattr(db_user, "country")
|
||||
user_country_code = user_country # 在User模型中,country字段就是country_code
|
||||
|
||||
# 获取 Lazer 用户资料
|
||||
profile = db_user.lazer_profile
|
||||
if not profile:
|
||||
# 如果没有 lazer 资料,使用默认值
|
||||
profile = LazerUserProfile(
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
|
||||
lzrcnt = db_user.lazer_counts
|
||||
|
||||
if not lzrcnt:
|
||||
# 如果没有 lazer 计数,使用默认值
|
||||
lzrcnt = LazerUserCounts(user_id=user_id)
|
||||
|
||||
# 获取指定模式的统计信息
|
||||
user_stats = None
|
||||
if db_user.lazer_statistics:
|
||||
for stat in db_user.lazer_statistics:
|
||||
if stat.mode == ruleset:
|
||||
user_stats = stat
|
||||
break
|
||||
|
||||
if not user_stats:
|
||||
# 如果没有找到指定模式的统计,创建默认统计
|
||||
user_stats = LazerUserStatistics(user_id=user_id)
|
||||
|
||||
# 获取国家信息
|
||||
country_code = db_user.country_code if db_user.country_code is not None else "XX"
|
||||
|
||||
country = Country(code=str(country_code), name=get_country_name(str(country_code)))
|
||||
|
||||
# 获取 Kudosu 信息
|
||||
kudosu = Kudosu(available=0, total=0)
|
||||
|
||||
# 获取计数信息
|
||||
# counts = LazerUserCounts(user_id=user_id)
|
||||
|
||||
# 转换统计信息
|
||||
statistics = Statistics(
|
||||
count_100=user_stats.count_100,
|
||||
count_300=user_stats.count_300,
|
||||
count_50=user_stats.count_50,
|
||||
count_miss=user_stats.count_miss,
|
||||
level=Level(
|
||||
current=user_stats.level_current, progress=user_stats.level_progress
|
||||
),
|
||||
global_rank=user_stats.global_rank,
|
||||
global_rank_exp=user_stats.global_rank_exp,
|
||||
pp=float(user_stats.pp) if user_stats.pp else 0.0,
|
||||
pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0,
|
||||
ranked_score=user_stats.ranked_score,
|
||||
hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0,
|
||||
play_count=user_stats.play_count,
|
||||
play_time=user_stats.play_time,
|
||||
total_score=user_stats.total_score,
|
||||
total_hits=user_stats.total_hits,
|
||||
maximum_combo=user_stats.maximum_combo,
|
||||
replays_watched_by_others=user_stats.replays_watched_by_others,
|
||||
is_ranked=user_stats.is_ranked,
|
||||
grade_counts=GradeCounts(
|
||||
ss=user_stats.grade_ss,
|
||||
ssh=user_stats.grade_ssh,
|
||||
s=user_stats.grade_s,
|
||||
sh=user_stats.grade_sh,
|
||||
a=user_stats.grade_a,
|
||||
),
|
||||
country_rank=user_stats.country_rank,
|
||||
rank={"country": user_stats.country_rank} if user_stats.country_rank else None,
|
||||
)
|
||||
|
||||
# 转换所有模式的统计信息
|
||||
statistics_rulesets = {}
|
||||
if db_user.lazer_statistics:
|
||||
for stat in db_user.lazer_statistics:
|
||||
statistics_rulesets[stat.mode] = Statistics(
|
||||
count_100=stat.count_100,
|
||||
count_300=stat.count_300,
|
||||
count_50=stat.count_50,
|
||||
count_miss=stat.count_miss,
|
||||
level=Level(current=stat.level_current, progress=stat.level_progress),
|
||||
global_rank=stat.global_rank,
|
||||
global_rank_exp=stat.global_rank_exp,
|
||||
pp=float(stat.pp) if stat.pp else 0.0,
|
||||
pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0,
|
||||
ranked_score=stat.ranked_score,
|
||||
hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0,
|
||||
play_count=stat.play_count,
|
||||
play_time=stat.play_time,
|
||||
total_score=stat.total_score,
|
||||
total_hits=stat.total_hits,
|
||||
maximum_combo=stat.maximum_combo,
|
||||
replays_watched_by_others=stat.replays_watched_by_others,
|
||||
is_ranked=stat.is_ranked,
|
||||
grade_counts=GradeCounts(
|
||||
ss=stat.grade_ss,
|
||||
ssh=stat.grade_ssh,
|
||||
s=stat.grade_s,
|
||||
sh=stat.grade_sh,
|
||||
a=stat.grade_a,
|
||||
),
|
||||
country_rank=stat.country_rank,
|
||||
rank={"country": stat.country_rank} if stat.country_rank else None,
|
||||
)
|
||||
|
||||
# 转换国家信息
|
||||
country = Country(code=user_country_code, name=get_country_name(user_country_code))
|
||||
|
||||
# 转换封面信息
|
||||
cover_url = (
|
||||
profile.cover_url
|
||||
if profile and profile.cover_url
|
||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
cover = Cover(
|
||||
custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None
|
||||
)
|
||||
|
||||
# 转换 Kudosu 信息
|
||||
kudosu = Kudosu(available=0, total=0)
|
||||
|
||||
# 转换成就信息
|
||||
user_achievements = []
|
||||
if db_user.lazer_achievements:
|
||||
for achievement in db_user.lazer_achievements:
|
||||
user_achievements.append(
|
||||
UserAchievement(
|
||||
achieved_at=achievement.achieved_at,
|
||||
achievement_id=achievement.achievement_id,
|
||||
)
|
||||
)
|
||||
|
||||
# 转换排名历史
|
||||
rank_history = None
|
||||
rank_history_data = None
|
||||
for rh in db_user.rank_history:
|
||||
if rh.mode == ruleset:
|
||||
rank_history_data = rh.rank_data
|
||||
break
|
||||
|
||||
if rank_history_data:
|
||||
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
|
||||
|
||||
# 转换每日挑战统计
|
||||
# daily_challenge_stats = None
|
||||
# if db_user.daily_challenge_stats:
|
||||
# dcs = db_user.daily_challenge_stats
|
||||
# daily_challenge_stats = DailyChallengeStats(
|
||||
# daily_streak_best=dcs.daily_streak_best,
|
||||
# daily_streak_current=dcs.daily_streak_current,
|
||||
# last_update=dcs.last_update,
|
||||
# last_weekly_streak=dcs.last_weekly_streak,
|
||||
# playcount=dcs.playcount,
|
||||
# top_10p_placements=dcs.top_10p_placements,
|
||||
# top_50p_placements=dcs.top_50p_placements,
|
||||
# user_id=dcs.user_id,
|
||||
# weekly_streak_best=dcs.weekly_streak_best,
|
||||
# weekly_streak_current=dcs.weekly_streak_current,
|
||||
# )
|
||||
|
||||
# 转换最高排名
|
||||
rank_highest = None
|
||||
if user_stats.rank_highest:
|
||||
rank_highest = RankHighest(
|
||||
rank=user_stats.rank_highest,
|
||||
updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(),
|
||||
)
|
||||
|
||||
# 转换团队信息
|
||||
team = None
|
||||
if db_user.team_membership:
|
||||
team_member = db_user.team_membership # 假设用户只属于一个团队
|
||||
team = team_member.team
|
||||
|
||||
# 创建用户对象
|
||||
# 从db_user获取基本字段值
|
||||
user_id = getattr(db_user, "id")
|
||||
user_name = getattr(db_user, "name")
|
||||
user_country = getattr(db_user, "country")
|
||||
|
||||
# 获取用户头像URL
|
||||
avatar_url = None
|
||||
|
||||
# 首先检查 profile 中的 avatar_url
|
||||
if profile and hasattr(profile, "avatar_url") and profile.avatar_url:
|
||||
avatar_url = str(profile.avatar_url)
|
||||
|
||||
# 然后检查是否有关联的头像记录
|
||||
if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None:
|
||||
if db_user.avatar.r2_game_url:
|
||||
# 优先使用游戏用的头像URL
|
||||
avatar_url = str(db_user.avatar.r2_game_url)
|
||||
elif db_user.avatar.r2_original_url:
|
||||
# 其次使用原始头像URL
|
||||
avatar_url = str(db_user.avatar.r2_original_url)
|
||||
|
||||
# 如果还是没有找到,通过查询获取
|
||||
# if db_session and avatar_url is None:
|
||||
# try:
|
||||
# # 导入UserAvatar模型
|
||||
|
||||
# # 尝试查找用户的头像记录
|
||||
# statement = select(UserAvatar).where(
|
||||
# UserAvatar.user_id == user_id, UserAvatar.is_active == True
|
||||
# )
|
||||
# avatar_record = db_session.exec(statement).first()
|
||||
# if avatar_record is not None:
|
||||
# if avatar_record.r2_game_url is not None:
|
||||
# # 优先使用游戏用的头像URL
|
||||
# avatar_url = str(avatar_record.r2_game_url)
|
||||
# elif avatar_record.r2_original_url is not None:
|
||||
# # 其次使用原始头像URL
|
||||
# avatar_url = str(avatar_record.r2_original_url)
|
||||
# except Exception as e:
|
||||
# print(f"获取用户头像时出错: {e}")
|
||||
# print(f"最终头像URL: {avatar_url}")
|
||||
# 如果仍然没有找到头像URL,则使用默认URL
|
||||
if avatar_url is None:
|
||||
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
|
||||
|
||||
# 处理 profile_order 列表排序
|
||||
profile_order = [
|
||||
"me",
|
||||
"recent_activity",
|
||||
"top_ranks",
|
||||
"medals",
|
||||
"historical",
|
||||
"beatmaps",
|
||||
"kudosu",
|
||||
]
|
||||
if profile and profile.profile_order:
|
||||
profile_order = profile.profile_order.split(",")
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
|
||||
active_tournament_banners = []
|
||||
if db_user.active_banners:
|
||||
for banner in db_user.active_banners:
|
||||
active_tournament_banners.append(
|
||||
{
|
||||
"tournament_id": banner.tournament_id,
|
||||
"image_url": banner.image_url,
|
||||
"is_active": banner.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加badges处理
|
||||
badges = []
|
||||
if db_user.lazer_badges:
|
||||
for badge in db_user.lazer_badges:
|
||||
badges.append(
|
||||
{
|
||||
"badge_id": badge.badge_id,
|
||||
"awarded_at": badge.awarded_at,
|
||||
"description": badge.description,
|
||||
"image_url": badge.image_url,
|
||||
"url": badge.url,
|
||||
}
|
||||
)
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
|
||||
monthly_playcounts = []
|
||||
if db_user.lazer_monthly_playcounts:
|
||||
for playcount in db_user.lazer_monthly_playcounts:
|
||||
monthly_playcounts.append(
|
||||
{
|
||||
"start_date": playcount.start_date.isoformat()
|
||||
if playcount.start_date
|
||||
else None,
|
||||
"play_count": playcount.play_count,
|
||||
}
|
||||
)
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
|
||||
previous_usernames = []
|
||||
if db_user.lazer_previous_usernames:
|
||||
for username in db_user.lazer_previous_usernames:
|
||||
previous_usernames.append(
|
||||
{
|
||||
"username": username.username,
|
||||
"changed_at": username.changed_at.isoformat()
|
||||
if username.changed_at
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加replays_watched_counts处理
|
||||
replays_watched_counts = []
|
||||
if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched:
|
||||
for replay in db_user.lazer_replays_watched:
|
||||
replays_watched_counts.append(
|
||||
{
|
||||
"start_date": replay.start_date.isoformat()
|
||||
if replay.start_date
|
||||
else None,
|
||||
"count": replay.count,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建用户对象
|
||||
user = User(
|
||||
id=user_id,
|
||||
username=user_name,
|
||||
avatar_url=avatar_url,
|
||||
country_code=str(country_code),
|
||||
default_group=profile.default_group if profile else "default",
|
||||
is_active=profile.is_active,
|
||||
is_bot=profile.is_bot,
|
||||
is_deleted=profile.is_deleted,
|
||||
is_online=profile.is_online,
|
||||
is_supporter=profile.is_supporter,
|
||||
is_restricted=profile.is_restricted,
|
||||
last_visit=db_user.last_visit,
|
||||
pm_friends_only=profile.pm_friends_only,
|
||||
profile_colour=profile.profile_colour,
|
||||
cover_url=profile.cover_url
|
||||
if profile and profile.cover_url
|
||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
|
||||
discord=profile.discord if profile else None,
|
||||
has_supported=profile.has_supported if profile else False,
|
||||
interests=profile.interests if profile else None,
|
||||
join_date=profile.join_date if profile.join_date else datetime.now(UTC),
|
||||
location=profile.location if profile else None,
|
||||
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
|
||||
max_friends=profile.max_friends if profile and profile.max_friends else 500,
|
||||
post_count=profile.post_count if profile and profile.post_count else 0,
|
||||
profile_hue=profile.profile_hue if profile and profile.profile_hue else None,
|
||||
profile_order=profile_order, # 使用排序后的 profile_order
|
||||
title=profile.title if profile else None,
|
||||
title_url=profile.title_url if profile else None,
|
||||
twitter=profile.twitter if profile else None,
|
||||
website=profile.website if profile else None,
|
||||
session_verified=True,
|
||||
support_level=profile.support_level if profile else 0,
|
||||
country=country,
|
||||
cover=cover,
|
||||
kudosu=kudosu,
|
||||
statistics=statistics,
|
||||
statistics_rulesets=statistics_rulesets,
|
||||
beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0,
|
||||
comments_count=lzrcnt.comments_count if lzrcnt else 0,
|
||||
favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0,
|
||||
follower_count=lzrcnt.follower_count if lzrcnt else 0,
|
||||
graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0,
|
||||
guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0,
|
||||
loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0,
|
||||
mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0,
|
||||
nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0,
|
||||
pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0,
|
||||
ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0,
|
||||
ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count
|
||||
if lzrcnt
|
||||
else 0,
|
||||
unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0,
|
||||
scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0,
|
||||
scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0,
|
||||
scores_pinned_count=lzrcnt.scores_pinned_count,
|
||||
scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0,
|
||||
account_history=[], # TODO: 获取用户历史账户信息
|
||||
# active_tournament_banner=len(active_tournament_banners),
|
||||
active_tournament_banners=active_tournament_banners,
|
||||
badges=badges,
|
||||
current_season_stats=None,
|
||||
daily_challenge_user_stats=DailyChallengeStats(
|
||||
user_id=user_id,
|
||||
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
last_update=db_user.daily_challenge_stats.last_update
|
||||
if db_user.daily_challenge_stats
|
||||
else None,
|
||||
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
|
||||
if db_user.daily_challenge_stats
|
||||
else None,
|
||||
playcount=db_user.daily_challenge_stats.playcount
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
|
||||
if db_user.daily_challenge_stats
|
||||
else 0,
|
||||
),
|
||||
groups=[],
|
||||
monthly_playcounts=monthly_playcounts,
|
||||
page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
|
||||
if profile.page_html or profile.page_raw
|
||||
else Page(),
|
||||
previous_usernames=previous_usernames,
|
||||
rank_highest=rank_highest,
|
||||
rank_history=rank_history,
|
||||
rankHistory=rank_history,
|
||||
replays_watched_counts=replays_watched_counts,
|
||||
team=team,
|
||||
user_achievements=user_achievements,
|
||||
)
|
||||
|
||||
return user
|
||||
def camel_to_snake(name: str) -> str:
|
||||
"""Convert a camelCase string to snake_case."""
|
||||
result = []
|
||||
last_chr = ""
|
||||
for char in name:
|
||||
if char.isupper():
|
||||
if not last_chr.isupper() and result:
|
||||
result.append("_")
|
||||
result.append(char.lower())
|
||||
else:
|
||||
result.append(char)
|
||||
last_chr = char
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def get_country_name(country_code: str) -> str:
|
||||
"""根据国家代码获取国家名称"""
|
||||
country_names = {
|
||||
"CN": "China",
|
||||
"JP": "Japan",
|
||||
"US": "United States",
|
||||
"GB": "United Kingdom",
|
||||
"DE": "Germany",
|
||||
"FR": "France",
|
||||
"KR": "South Korea",
|
||||
"CA": "Canada",
|
||||
"AU": "Australia",
|
||||
"BR": "Brazil",
|
||||
# 可以添加更多国家
|
||||
def snake_to_camel(name: str, use_abbr: bool = True) -> str:
|
||||
"""Convert a snake_case string to camelCase."""
|
||||
if not name:
|
||||
return name
|
||||
|
||||
parts = name.split("_")
|
||||
if not parts:
|
||||
return name
|
||||
|
||||
# 常见缩写词列表
|
||||
abbreviations = {
|
||||
"id",
|
||||
"url",
|
||||
"api",
|
||||
"http",
|
||||
"https",
|
||||
"xml",
|
||||
"json",
|
||||
"css",
|
||||
"html",
|
||||
"sql",
|
||||
"db",
|
||||
}
|
||||
return country_names.get(country_code, "Unknown")
|
||||
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.lower() in abbreviations and use_abbr:
|
||||
result.append(part.upper())
|
||||
else:
|
||||
if result:
|
||||
result.append(part.capitalize())
|
||||
else:
|
||||
result.append(part.lower())
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
|
||||
"""Convert a snake_case string to PascalCase."""
|
||||
if not name:
|
||||
return name
|
||||
|
||||
parts = name.split("_")
|
||||
if not parts:
|
||||
return name
|
||||
|
||||
# 常见缩写词列表
|
||||
abbreviations = {
|
||||
"id",
|
||||
"url",
|
||||
"api",
|
||||
"http",
|
||||
"https",
|
||||
"xml",
|
||||
"json",
|
||||
"css",
|
||||
"html",
|
||||
"sql",
|
||||
"db",
|
||||
}
|
||||
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.lower() in abbreviations and use_abbr:
|
||||
result.append(part.upper())
|
||||
else:
|
||||
result.append(part.capitalize())
|
||||
|
||||
return "".join(result)
|
||||
|
||||
114
main.py
114
main.py
@@ -4,25 +4,27 @@ from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.database import Team # noqa: F401
|
||||
from app.dependencies.database import create_tables, engine
|
||||
from app.dependencies.database import create_tables, engine, redis_client
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.models.user import User
|
||||
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
||||
from app.router import (
|
||||
api_router,
|
||||
auth_router,
|
||||
fetcher_router,
|
||||
signalr_router,
|
||||
)
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
User.model_rebuild()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# on startup
|
||||
await create_tables()
|
||||
get_fetcher() # 初始化 fetcher
|
||||
await get_fetcher() # 初始化 fetcher
|
||||
# on shutdown
|
||||
yield
|
||||
await engine.dispose()
|
||||
await redis_client.aclose()
|
||||
|
||||
|
||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
||||
@@ -44,104 +46,6 @@ async def health_check():
|
||||
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
|
||||
|
||||
|
||||
# @app.get("/api/v2/friends")
|
||||
# async def get_friends():
|
||||
# return JSONResponse(
|
||||
# content=[
|
||||
# {
|
||||
# "id": 123456,
|
||||
# "username": "BestFriend",
|
||||
# "is_online": True,
|
||||
# "is_supporter": False,
|
||||
# "country": {"code": "US", "name": "United States"},
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
|
||||
|
||||
# @app.get("/api/v2/notifications")
|
||||
# async def get_notifications():
|
||||
# return JSONResponse(content={"notifications": [], "unread_count": 0})
|
||||
|
||||
|
||||
# @app.post("/api/v2/chat/ack")
|
||||
# async def chat_ack():
|
||||
# return JSONResponse(content={"status": "ok"})
|
||||
|
||||
|
||||
# @app.get("/api/v2/users/{user_id}/{mode}")
|
||||
# async def get_user_mode(user_id: int, mode: str):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "id": user_id,
|
||||
# "username": "测试测试测",
|
||||
# "statistics": {
|
||||
# "level": {"current": 97, "progress": 96},
|
||||
# "pp": 114514,
|
||||
# "global_rank": 666,
|
||||
# "country_rank": 1,
|
||||
# "hit_accuracy": 100,
|
||||
# },
|
||||
# "country": {"code": "JP", "name": "Japan"},
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.get("/api/v2/me")
|
||||
# async def get_me():
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "id": 15651670,
|
||||
# "username": "Googujiang",
|
||||
# "is_online": True,
|
||||
# "country": {"code": "JP", "name": "Japan"},
|
||||
# "statistics": {
|
||||
# "level": {"current": 97, "progress": 96},
|
||||
# "pp": 2826.26,
|
||||
# "global_rank": 298026,
|
||||
# "country_rank": 11220,
|
||||
# "hit_accuracy": 95.7168,
|
||||
# },
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.post("/signalr/metadata/negotiate")
|
||||
# async def metadata_negotiate(negotiateVersion: int = 1):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "connectionId": "abc123",
|
||||
# "availableTransports": [
|
||||
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
|
||||
# ],
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.post("/signalr/spectator/negotiate")
|
||||
# async def spectator_negotiate(negotiateVersion: int = 1):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "connectionId": "spec456",
|
||||
# "availableTransports": [
|
||||
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
|
||||
# ],
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
# @app.post("/signalr/multiplayer/negotiate")
|
||||
# async def multiplayer_negotiate(negotiateVersion: int = 1):
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "connectionId": "multi789",
|
||||
# "availableTransports": [
|
||||
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
|
||||
# ],
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from app.log import logger # noqa: F401
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""score: remove best_id in database
|
||||
"""beatmapset: support favourite count
|
||||
|
||||
Revision ID: 78be13c71791
|
||||
Revises: dc4d25c428c7
|
||||
Create Date: 2025-07-29 07:57:33.764517
|
||||
Revision ID: 1178d0758ebf
|
||||
Revises:
|
||||
Create Date: 2025-08-01 04:05:09.882800
|
||||
|
||||
"""
|
||||
|
||||
@@ -15,8 +15,8 @@ import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "78be13c71791"
|
||||
down_revision: str | Sequence[str] | None = "dc4d25c428c7"
|
||||
revision: str = "1178d0758ebf"
|
||||
down_revision: str | Sequence[str] | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
@@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("scores", "best_id")
|
||||
op.drop_column("beatmapsets", "favourite_count")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -32,7 +32,9 @@ def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"scores",
|
||||
sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True),
|
||||
"beatmapsets",
|
||||
sa.Column(
|
||||
"favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,54 @@
|
||||
"""relationship: fix unique relationship
|
||||
|
||||
Revision ID: 58a11441d302
|
||||
Revises: 1178d0758ebf
|
||||
Create Date: 2025-08-01 04:23:02.498166
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "58a11441d302"
|
||||
down_revision: str | Sequence[str] | None = "1178d0758ebf"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"relationship",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
)
|
||||
op.drop_constraint("PRIMARY", "relationship", type_="primary")
|
||||
op.create_primary_key("pk_relationship", "relationship", ["id"])
|
||||
op.alter_column(
|
||||
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True
|
||||
)
|
||||
op.alter_column(
|
||||
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("pk_relationship", "relationship", type_="primary")
|
||||
op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"])
|
||||
op.alter_column(
|
||||
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False
|
||||
)
|
||||
op.alter_column(
|
||||
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False
|
||||
)
|
||||
op.drop_column("relationship", "id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,89 @@
|
||||
"""playlist: index playlist id
|
||||
|
||||
Revision ID: d0c1b2cefe91
|
||||
Revises: 58a11441d302
|
||||
Create Date: 2025-08-06 06:02:10.512616
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d0c1b2cefe91"
|
||||
down_revision: str | Sequence[str] | None = "58a11441d302"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_index(
|
||||
op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"playlist_best_scores",
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("score_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("room_id", sa.Integer(), nullable=False),
|
||||
sa.Column("playlist_id", sa.Integer(), nullable=False),
|
||||
sa.Column("total_score", sa.BigInteger(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["playlist_id"],
|
||||
["room_playlists.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["room_id"],
|
||||
["rooms.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["score_id"],
|
||||
["scores.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("score_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_playlist_best_scores_playlist_id"),
|
||||
"playlist_best_scores",
|
||||
["playlist_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_playlist_best_scores_room_id"),
|
||||
"playlist_best_scores",
|
||||
["room_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_playlist_best_scores_user_id"),
|
||||
"playlist_best_scores",
|
||||
["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(
|
||||
op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores"
|
||||
)
|
||||
op.drop_table("playlist_best_scores")
|
||||
op.drop_index(op.f("ix_room_playlists_id"), table_name="room_playlists")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,36 +0,0 @@
|
||||
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
|
||||
|
||||
Revision ID: dc4d25c428c7
|
||||
Revises:
|
||||
Create Date: 2025-07-29 01:43:40.221070
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dc4d25c428c7"
|
||||
down_revision: str | Sequence[str] | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
|
||||
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("scores", "nsmall_tick_hit")
|
||||
op.drop_column("scores", "nlarge_tick_hit")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,11 +1,4 @@
|
||||
from typing import Any
|
||||
|
||||
class APIMod:
|
||||
def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ...
|
||||
@property
|
||||
def acronym(self) -> str: ...
|
||||
@property
|
||||
def settings(self) -> str: ...
|
||||
|
||||
def encode(obj: Any) -> bytes: ...
|
||||
def decode(data: bytes) -> Any: ...
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use crate::APIMod;
|
||||
use chrono::{TimeZone, Utc};
|
||||
use pyo3::types::PyDict;
|
||||
use pyo3::{prelude::*, IntoPyObjectExt};
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
|
||||
pub fn read_object(
|
||||
@@ -13,6 +11,8 @@ pub fn read_object(
|
||||
match rmp::decode::read_marker(cursor) {
|
||||
Ok(marker) => match marker {
|
||||
rmp::Marker::Null => Ok(py.None()),
|
||||
rmp::Marker::True => Ok(true.into_py_any(py)?),
|
||||
rmp::Marker::False => Ok(false.into_py_any(py)?),
|
||||
rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()),
|
||||
rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()),
|
||||
rmp::Marker::U8 => {
|
||||
@@ -86,8 +86,6 @@ pub fn read_object(
|
||||
cursor.read_exact(&mut data).map_err(to_py_err)?;
|
||||
Ok(data.into_pyobject(py)?.into_any().unbind())
|
||||
}
|
||||
rmp::Marker::True => Ok(true.into_py_any(py)?),
|
||||
rmp::Marker::False => Ok(false.into_py_any(py)?),
|
||||
rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32),
|
||||
rmp::Marker::Str8 => {
|
||||
let mut buf = [0u8; 1];
|
||||
@@ -206,13 +204,12 @@ fn read_array(
|
||||
let obj1 = read_object(py, cursor, false)?;
|
||||
if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) {
|
||||
let obj2 = read_object(py, cursor, true)?;
|
||||
return Ok(APIMod {
|
||||
acronym: obj1.extract::<String>(py)?,
|
||||
settings: obj2.extract::<HashMap<String, PyObject>>(py)?,
|
||||
}
|
||||
.into_pyobject(py)?
|
||||
.into_any()
|
||||
.unbind());
|
||||
|
||||
let api_mod_dict = PyDict::new(py);
|
||||
api_mod_dict.set_item("acronym", obj1)?;
|
||||
api_mod_dict.set_item("settings", obj2)?;
|
||||
|
||||
return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind());
|
||||
} else {
|
||||
items.push(obj1);
|
||||
i += 1;
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use crate::APIMod;
|
||||
use chrono::{DateTime, Utc};
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods};
|
||||
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
|
||||
use pyo3::{Bound, PyAny, PyRef, Python};
|
||||
use pyo3::{Bound, PyAny};
|
||||
use std::io::Write;
|
||||
|
||||
fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) {
|
||||
@@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec<u8>, obj: &Bound<'_, PyDict>) {
|
||||
}
|
||||
}
|
||||
|
||||
fn write_nil(buf: &mut Vec<u8>){
|
||||
fn write_nil(buf: &mut Vec<u8>) {
|
||||
rmp::encode::write_nil(buf).unwrap();
|
||||
}
|
||||
|
||||
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
||||
fn write_api_mod(buf: &mut Vec<u8>, api_mod: PyRef<APIMod>) {
|
||||
rmp::encode::write_array_len(buf, 2).unwrap();
|
||||
rmp::encode::write_str(buf, &api_mod.acronym).unwrap();
|
||||
rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap();
|
||||
for (k, v) in api_mod.settings.iter() {
|
||||
rmp::encode::write_str(buf, k).unwrap();
|
||||
Python::with_gil(|py| write_object(buf, &v.bind(py)));
|
||||
fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool {
|
||||
if let Ok(Some(acronym)) = dict.get_item("acronym") {
|
||||
if let Ok(acronym_str) = acronym.extract::<String>() {
|
||||
return acronym_str.len() == 2;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
||||
fn write_api_mod(buf: &mut Vec<u8>, api_mod: &Bound<'_, PyDict>) -> PyResult<()> {
|
||||
let acronym = api_mod
|
||||
.get_item("acronym")?
|
||||
.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?;
|
||||
let acronym_str = acronym.extract::<String>()?;
|
||||
|
||||
let settings = api_mod
|
||||
.get_item("settings")?
|
||||
.unwrap_or_else(|| PyDict::new(acronym.py()).into_any());
|
||||
let settings_dict = settings.downcast::<PyDict>()?;
|
||||
|
||||
rmp::encode::write_array_len(buf, 2).unwrap();
|
||||
rmp::encode::write_str(buf, &acronym_str).unwrap();
|
||||
rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap();
|
||||
|
||||
for (k, v) in settings_dict.iter() {
|
||||
let key_str = k.extract::<String>()?;
|
||||
rmp::encode::write_str(buf, &key_str).unwrap();
|
||||
write_object(buf, &v);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) {
|
||||
@@ -110,22 +132,24 @@ pub fn write_object(buf: &mut Vec<u8>, obj: &Bound<'_, PyAny>) {
|
||||
write_list(buf, list);
|
||||
} else if let Ok(string) = obj.downcast::<PyString>() {
|
||||
write_string(buf, string);
|
||||
} else if let Ok(integer) = obj.downcast::<PyInt>() {
|
||||
write_integer(buf, integer);
|
||||
} else if let Ok(float) = obj.downcast::<PyFloat>() {
|
||||
write_float(buf, float);
|
||||
} else if let Ok(boolean) = obj.downcast::<PyBool>() {
|
||||
write_bool(buf, boolean);
|
||||
} else if let Ok(float) = obj.downcast::<PyFloat>() {
|
||||
write_float(buf, float);
|
||||
} else if let Ok(integer) = obj.downcast::<PyInt>() {
|
||||
write_integer(buf, integer);
|
||||
} else if let Ok(bytes) = obj.downcast::<PyBytes>() {
|
||||
write_bin(buf, bytes);
|
||||
} else if let Ok(dict) = obj.downcast::<PyDict>() {
|
||||
write_hashmap(buf, dict);
|
||||
if is_api_mod(dict) {
|
||||
write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict));
|
||||
} else {
|
||||
write_hashmap(buf, dict);
|
||||
}
|
||||
} else if let Ok(_none) = obj.downcast::<PyNone>() {
|
||||
write_nil(buf);
|
||||
} else if let Ok(datetime) = obj.downcast::<PyDateTime>() {
|
||||
write_datetime(buf, datetime);
|
||||
} else if let Ok(api_mod) = obj.extract::<PyRef<APIMod>>() {
|
||||
write_api_mod(buf, api_mod);
|
||||
} else {
|
||||
panic!("Unsupported type");
|
||||
}
|
||||
|
||||
@@ -2,30 +2,6 @@ mod decode;
|
||||
mod encode;
|
||||
|
||||
use pyo3::prelude::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[pyclass]
|
||||
struct APIMod {
|
||||
#[pyo3(get, set)]
|
||||
acronym: String,
|
||||
#[pyo3(get, set)]
|
||||
settings: HashMap<String, PyObject>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl APIMod {
|
||||
#[new]
|
||||
fn new(acronym: String, settings: HashMap<String, PyObject>) -> Self {
|
||||
APIMod { acronym, settings }
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!(
|
||||
"APIMod(acronym='{}', settings={:?})",
|
||||
self.acronym, self.settings
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(name = "encode")]
|
||||
@@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult<PyObject> {
|
||||
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(encode_py, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(decode_py, m)?)?;
|
||||
m.add_class::<APIMod>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
- `mods.json`: 包含了游戏中的所有可用mod的详细信息。
|
||||
- Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json
|
||||
- Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380`
|
||||
- Version: 2025/7/30 `ff49b66b27a2850aea4b6b3ba563cfe936cb6082`
|
||||
|
||||
@@ -2438,7 +2438,8 @@
|
||||
"Settings": [],
|
||||
"IncompatibleMods": [
|
||||
"CN",
|
||||
"RX"
|
||||
"RX",
|
||||
"MF"
|
||||
],
|
||||
"RequiresConfiguration": false,
|
||||
"UserPlayable": false,
|
||||
@@ -2460,7 +2461,8 @@
|
||||
"AC",
|
||||
"AT",
|
||||
"CN",
|
||||
"RX"
|
||||
"RX",
|
||||
"MF"
|
||||
],
|
||||
"RequiresConfiguration": false,
|
||||
"UserPlayable": false,
|
||||
@@ -2477,7 +2479,8 @@
|
||||
"Settings": [],
|
||||
"IncompatibleMods": [
|
||||
"AT",
|
||||
"CN"
|
||||
"CN",
|
||||
"MF"
|
||||
],
|
||||
"RequiresConfiguration": false,
|
||||
"UserPlayable": true,
|
||||
@@ -2638,6 +2641,24 @@
|
||||
"ValidForMultiplayerAsFreeMod": true,
|
||||
"AlwaysValidForSubmission": false
|
||||
},
|
||||
{
|
||||
"Acronym": "MF",
|
||||
"Name": "Moving Fast",
|
||||
"Description": "Dashing by default, slow down!",
|
||||
"Type": "Fun",
|
||||
"Settings": [],
|
||||
"IncompatibleMods": [
|
||||
"AT",
|
||||
"CN",
|
||||
"RX"
|
||||
],
|
||||
"RequiresConfiguration": false,
|
||||
"UserPlayable": true,
|
||||
"ValidForMultiplayer": true,
|
||||
"ValidForFreestyleAsRequiredMod": false,
|
||||
"ValidForMultiplayerAsFreeMod": true,
|
||||
"AlwaysValidForSubmission": false
|
||||
},
|
||||
{
|
||||
"Acronym": "SV2",
|
||||
"Name": "Score V2",
|
||||
|
||||
Reference in New Issue
Block a user