chore(merge): merge pull request #8 from feat/multiplayer-api
feat: 增加mp房间相关接口
This commit is contained in:
@@ -16,10 +16,22 @@ from .lazer_user import (
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from .playlist_attempts import (
|
||||
ItemAttemptsCount,
|
||||
ItemAttemptsResp,
|
||||
PlaylistAggregateScore,
|
||||
)
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .pp_best_score import PPBestScore
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .room import APIUploadedRoom, Room, RoomResp
|
||||
from .room_participated_user import RoomParticipatedUser
|
||||
from .score import (
|
||||
MultiplayerScores,
|
||||
Score,
|
||||
ScoreAround,
|
||||
ScoreBase,
|
||||
ScoreResp,
|
||||
ScoreStatistics,
|
||||
@@ -37,6 +49,7 @@ from .user_account_history import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"APIUploadedRoom",
|
||||
"Beatmap",
|
||||
"BeatmapPlaycounts",
|
||||
"BeatmapPlaycountsResp",
|
||||
@@ -46,12 +59,25 @@ __all__ = [
|
||||
"DailyChallengeStats",
|
||||
"DailyChallengeStatsResp",
|
||||
"FavouriteBeatmapset",
|
||||
"ItemAttemptsCount",
|
||||
"ItemAttemptsResp",
|
||||
"MultiplayerEvent",
|
||||
"MultiplayerEventResp",
|
||||
"MultiplayerScores",
|
||||
"OAuthToken",
|
||||
"PPBestScore",
|
||||
"Playlist",
|
||||
"PlaylistAggregateScore",
|
||||
"PlaylistBestScore",
|
||||
"PlaylistResp",
|
||||
"Relationship",
|
||||
"RelationshipResp",
|
||||
"RelationshipType",
|
||||
"Room",
|
||||
"RoomParticipatedUser",
|
||||
"RoomResp",
|
||||
"Score",
|
||||
"ScoreAround",
|
||||
"ScoreBase",
|
||||
"ScoreResp",
|
||||
"ScoreStatistics",
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import MODE_TO_INT, GameMode
|
||||
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
@@ -23,7 +22,7 @@ class BeatmapOwner(SQLModel):
|
||||
username: str
|
||||
|
||||
|
||||
class BeatmapBase(SQLModel, UTCBaseModel):
|
||||
class BeatmapBase(SQLModel):
|
||||
# Beatmap
|
||||
url: str
|
||||
mode: GameMode
|
||||
@@ -63,7 +62,7 @@ class BeatmapBase(SQLModel, UTCBaseModel):
|
||||
|
||||
class Beatmap(BeatmapBase, table=True):
|
||||
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
beatmap_status: BeatmapRankStatus
|
||||
# optional
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .lazer_user import BASE_INCLUDES, User, UserResp
|
||||
@@ -14,6 +13,8 @@ from sqlmodel import Field, Relationship, SQLModel, col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
@@ -87,7 +88,7 @@ class BeatmapTranslationText(BaseModel):
|
||||
id: int | None = None
|
||||
|
||||
|
||||
class BeatmapsetBase(SQLModel, UTCBaseModel):
|
||||
class BeatmapsetBase(SQLModel):
|
||||
# Beatmapset
|
||||
artist: str = Field(index=True)
|
||||
artist_unicode: str = Field(index=True)
|
||||
@@ -186,6 +187,16 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
|
||||
return beatmapset
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
|
||||
) -> "Beatmapset":
|
||||
beatmapset = await session.get(Beatmapset, sid)
|
||||
if not beatmapset:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
beatmapset = await cls.from_resp(session, resp)
|
||||
return beatmapset
|
||||
|
||||
|
||||
class BeatmapsetResp(BeatmapsetBase):
|
||||
id: int
|
||||
|
||||
@@ -29,9 +29,7 @@ class BestScore(SQLModel, table=True):
|
||||
)
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
total_score: int = Field(
|
||||
default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score"))
|
||||
)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
mods: list[str] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
|
||||
56
app/database/multiplayer_event.py
Normal file
56
app/database/multiplayer_event.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerEventBase(SQLModel, UTCBaseModel):
|
||||
playlist_item_id: int | None = None
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
event_type: str = Field(index=True)
|
||||
|
||||
|
||||
class MultiplayerEvent(MultiplayerEventBase, table=True):
|
||||
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
event_detail: dict[str, Any] | None = Field(
|
||||
sa_column=Column(JSON),
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerEventResp(MultiplayerEventBase):
|
||||
id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, event: MultiplayerEvent) -> "MultiplayerEventResp":
|
||||
return cls.model_validate(event)
|
||||
151
app/database/playlist_attempts.py
Normal file
151
app/database/playlist_attempts.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ItemAttemptsCountBase(SQLModel):
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
attempts: int = Field(default=0)
|
||||
completed: int = Field(default=0)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
accuracy: float = 0.0
|
||||
pp: float = 0
|
||||
total_score: int = 0
|
||||
|
||||
|
||||
class ItemAttemptsCount(ItemAttemptsCountBase, table=True):
|
||||
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
user: User = Relationship()
|
||||
|
||||
async def get_position(self, session: AsyncSession) -> int:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(ItemAttemptsCountBase.room_id),
|
||||
order_by=col(ItemAttemptsCountBase.total_score).desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(ItemAttemptsCountBase, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
async def update(self, session: AsyncSession):
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == self.room_id,
|
||||
PlaylistBestScore.user_id == self.user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
self.attempts = sum(score.attempts for score in playlist_scores)
|
||||
self.total_score = sum(score.total_score for score in playlist_scores)
|
||||
self.pp = sum(score.score.pp for score in playlist_scores)
|
||||
self.completed = len(playlist_scores)
|
||||
self.accuracy = (
|
||||
sum(score.score.accuracy for score in playlist_scores) / self.completed
|
||||
if self.completed > 0
|
||||
else 0.0
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
@classmethod
|
||||
async def get_or_create(
|
||||
cls,
|
||||
room_id: int,
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
) -> "ItemAttemptsCount":
|
||||
item_attempts = await session.exec(
|
||||
select(cls).where(
|
||||
cls.room_id == room_id,
|
||||
cls.user_id == user_id,
|
||||
)
|
||||
)
|
||||
item_attempts = item_attempts.first()
|
||||
if item_attempts is None:
|
||||
item_attempts = cls(room_id=room_id, user_id=user_id)
|
||||
session.add(item_attempts)
|
||||
await session.commit()
|
||||
await session.refresh(item_attempts)
|
||||
await item_attempts.update(session)
|
||||
return item_attempts
|
||||
|
||||
|
||||
class ItemAttemptsResp(ItemAttemptsCountBase):
|
||||
user: UserResp | None = None
|
||||
position: int | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
item_attempts: ItemAttemptsCount,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
) -> "ItemAttemptsResp":
|
||||
resp = cls.model_validate(item_attempts.model_dump())
|
||||
resp.user = await UserResp.from_db(
|
||||
item_attempts.user,
|
||||
session=session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
)
|
||||
if "position" in include:
|
||||
resp.position = await item_attempts.get_position(session)
|
||||
# resp.accuracy *= 100
|
||||
return resp
|
||||
|
||||
|
||||
class ItemAttemptsCountForItem(BaseModel):
|
||||
id: int
|
||||
attempts: int
|
||||
passed: bool
|
||||
|
||||
|
||||
class PlaylistAggregateScore(BaseModel):
|
||||
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
room_id: int,
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
) -> "PlaylistAggregateScore":
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
playlist_item_attempts = []
|
||||
for score in playlist_scores:
|
||||
playlist_item_attempts.append(
|
||||
ItemAttemptsCountForItem(
|
||||
id=score.playlist_id,
|
||||
attempts=score.attempts,
|
||||
passed=score.score.passed,
|
||||
)
|
||||
)
|
||||
return cls(playlist_item_attempts=playlist_item_attempts)
|
||||
110
app/database/playlist_best_score.py
Normal file
110
app/database/playlist_best_score.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .lazer_user import User
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .score import Score
|
||||
|
||||
|
||||
class PlaylistBestScore(SQLModel, table=True):
|
||||
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
playlist_id: int = Field(foreign_key="room_playlists.id", index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
attempts: int = Field(default=0) # playlist
|
||||
|
||||
user: User = Relationship()
|
||||
score: "Score" = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[PlaylistBestScore.score_id]",
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def process_playlist_best_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
score_id: int,
|
||||
total_score: int,
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
):
|
||||
previous = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if previous is None:
|
||||
previous = PlaylistBestScore(
|
||||
user_id=user_id,
|
||||
score_id=score_id,
|
||||
room_id=room_id,
|
||||
playlist_id=playlist_id,
|
||||
total_score=total_score,
|
||||
)
|
||||
session.add(previous)
|
||||
elif not previous.score.passed or previous.total_score < total_score:
|
||||
previous.score_id = score_id
|
||||
previous.total_score = total_score
|
||||
previous.attempts += 1
|
||||
await session.commit()
|
||||
if await redis.exists(f"multiplayer:{room_id}:gameplay:players"):
|
||||
await redis.decr(f"multiplayer:{room_id}:gameplay:players")
|
||||
|
||||
|
||||
async def get_position(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
session: AsyncSession,
|
||||
) -> int:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=(
|
||||
col(PlaylistBestScore.playlist_id),
|
||||
col(PlaylistBestScore.room_id),
|
||||
),
|
||||
order_by=col(PlaylistBestScore.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = (
|
||||
select(PlaylistBestScore, rownum)
|
||||
.where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
.subquery()
|
||||
)
|
||||
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
143
app/database/playlists.py
Normal file
143
app/database/playlists.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .room import Room
|
||||
|
||||
|
||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
id: int = Field(index=True)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
ruleset_id: int = Field(ge=0, le=3)
|
||||
expired: bool = Field(default=False)
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
)
|
||||
allowed_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
required_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
freestyle: bool = Field(default=False)
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
beatmap: Beatmap = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "joined",
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
|
||||
@classmethod
|
||||
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
|
||||
cls.room_id == room_id
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
owner_id=playlist.owner_id,
|
||||
ruleset_id=playlist.ruleset_id,
|
||||
beatmap_id=playlist.beatmap_id,
|
||||
required_mods=playlist.required_mods,
|
||||
allowed_mods=playlist.allowed_mods,
|
||||
expired=playlist.expired,
|
||||
playlist_order=playlist.playlist_order,
|
||||
played_at=playlist.played_at,
|
||||
freestyle=playlist.freestyle,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
db_playlist.owner_id = playlist.owner_id
|
||||
db_playlist.ruleset_id = playlist.ruleset_id
|
||||
db_playlist.beatmap_id = playlist.beatmap_id
|
||||
db_playlist.required_mods = playlist.required_mods
|
||||
db_playlist.allowed_mods = playlist.allowed_mods
|
||||
db_playlist.expired = playlist.expired
|
||||
db_playlist.playlist_order = playlist.playlist_order
|
||||
db_playlist.played_at = playlist.played_at
|
||||
db_playlist.freestyle = playlist.freestyle
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
await session.refresh(db_playlist)
|
||||
playlist.id = db_playlist.id
|
||||
|
||||
@classmethod
|
||||
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == item_id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
await session.delete(db_playlist)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, playlist: Playlist, include: list[str] = []
|
||||
) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
resp = cls.model_validate(data)
|
||||
return resp
|
||||
@@ -1,6 +1,177 @@
|
||||
from sqlmodel import Field, SQLModel
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.database.playlist_attempts import PlaylistAggregateScore
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
from app.models.room import (
|
||||
MatchType,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomDifficultyRange,
|
||||
RoomPlaylistItemStats,
|
||||
RoomStatus,
|
||||
)
|
||||
|
||||
from .lazer_user import User, UserResp
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class RoomIndex(SQLModel, table=True):
|
||||
__tablename__ = "mp_room_index" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(default=None, primary_key=True, index=True) # pyright: ignore[reportCallIssue]
|
||||
class RoomBase(SQLModel, UTCBaseModel):
|
||||
name: str = Field(index=True)
|
||||
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||
duration: int | None = Field(default=None) # minutes
|
||||
starts_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
ends_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
participant_count: int = Field(default=0)
|
||||
max_attempts: int | None = Field(default=None) # playlists
|
||||
type: MatchType
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
status: RoomStatus
|
||||
# TODO: channel_id
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
host_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "selectin",
|
||||
"cascade": "all, delete-orphan",
|
||||
"overlaps": "room",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
has_password: bool = False
|
||||
host: UserResp | None = None
|
||||
playlist: list[PlaylistResp] = []
|
||||
playlist_item_stats: RoomPlaylistItemStats | None = None
|
||||
difficulty_range: RoomDifficultyRange | None = None
|
||||
current_playlist_item: PlaylistResp | None = None
|
||||
current_user_score: PlaylistAggregateScore | None = None
|
||||
recent_participants: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
room: Room,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
user: User | None = None,
|
||||
) -> "RoomResp":
|
||||
resp = cls.model_validate(room.model_dump())
|
||||
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
||||
difficulty_range = RoomDifficultyRange(
|
||||
min=0,
|
||||
max=0,
|
||||
)
|
||||
rulesets = set()
|
||||
for playlist in room.playlist:
|
||||
stats.count_total += 1
|
||||
if not playlist.expired:
|
||||
stats.count_active += 1
|
||||
rulesets.add(playlist.ruleset_id)
|
||||
difficulty_range.min = min(
|
||||
difficulty_range.min, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.max = max(
|
||||
difficulty_range.max, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
||||
stats.ruleset_ids = list(rulesets)
|
||||
resp.playlist_item_stats = stats
|
||||
resp.difficulty_range = difficulty_range
|
||||
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
|
||||
resp.recent_participants = []
|
||||
for recent_participant in await session.exec(
|
||||
select(RoomParticipatedUser)
|
||||
.where(
|
||||
RoomParticipatedUser.room_id == room.id,
|
||||
col(RoomParticipatedUser.left_at).is_(None),
|
||||
)
|
||||
.limit(8)
|
||||
.order_by(col(RoomParticipatedUser.joined_at).desc())
|
||||
):
|
||||
resp.recent_participants.append(
|
||||
await UserResp.from_db(
|
||||
await recent_participant.awaitable_attrs.user,
|
||||
session,
|
||||
include=["statistics"],
|
||||
)
|
||||
)
|
||||
resp.host = await UserResp.from_db(
|
||||
await room.awaitable_attrs.host, session, include=["statistics"]
|
||||
)
|
||||
if "current_user_score" in include and user:
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(
|
||||
room.id, user.id, session
|
||||
)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, server_room: ServerMultiplayerRoom) -> "RoomResp":
|
||||
room = server_room.room
|
||||
resp = cls(
|
||||
id=room.room_id,
|
||||
name=room.settings.name,
|
||||
type=room.settings.match_type,
|
||||
queue_mode=room.settings.queue_mode,
|
||||
auto_skip=room.settings.auto_skip,
|
||||
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
|
||||
status=server_room.status,
|
||||
category=server_room.category,
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class APIUploadedRoom(RoomBase):
|
||||
def to_room(self) -> Room:
|
||||
"""
|
||||
将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
|
||||
"""
|
||||
room_dict = self.model_dump()
|
||||
room_dict.pop("playlist", None)
|
||||
# host_id 已在字段中
|
||||
return Room(**room_dict)
|
||||
|
||||
id: int | None
|
||||
host_id: int | None = None
|
||||
playlist: list[Playlist] = Field(default_factory=list)
|
||||
|
||||
39
app/database/room_participated_user.py
Normal file
39
app/database/room_participated_user.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
from .room import Room
|
||||
|
||||
|
||||
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
)
|
||||
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
|
||||
)
|
||||
joined_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
left_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
|
||||
)
|
||||
|
||||
room: "Room" = Relationship()
|
||||
user: "User" = Relationship()
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Sequence
|
||||
from datetime import UTC, date, datetime
|
||||
import json
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.calculator import (
|
||||
calculate_pp,
|
||||
@@ -14,7 +14,7 @@ from app.calculator import (
|
||||
clamp,
|
||||
)
|
||||
from app.database.team import TeamMember
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.model import RespWithCursor, UTCBaseModel
|
||||
from app.models.mods import APIMod, mods_can_get_pp
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
@@ -89,10 +89,11 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
default=0, sa_column=Column(BigInteger), exclude=True
|
||||
)
|
||||
type: str
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
|
||||
# optional
|
||||
# TODO: current_user_attributes
|
||||
position: int | None = Field(default=None) # multiplayer
|
||||
# position: int | None = Field(default=None) # multiplayer
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
@@ -100,7 +101,6 @@ class Score(ScoreBase, table=True):
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
|
||||
)
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
@@ -163,6 +163,8 @@ class ScoreResp(ScoreBase):
|
||||
maximum_statistics: ScoreStatistics | None = None
|
||||
rank_global: int | None = None
|
||||
rank_country: int | None = None
|
||||
position: int | None = None
|
||||
scores_around: "ScoreAround | None" = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
@@ -234,6 +236,16 @@ class ScoreResp(ScoreBase):
|
||||
return s
|
||||
|
||||
|
||||
class MultiplayerScores(RespWithCursor):
|
||||
scores: list[ScoreResp] = Field(default_factory=list)
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ScoreAround(SQLModel):
|
||||
higher: MultiplayerScores | None = None
|
||||
lower: MultiplayerScores | None = None
|
||||
|
||||
|
||||
async def get_best_id(session: AsyncSession, score_id: int) -> None:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
@@ -329,6 +341,10 @@ async def get_leaderboard(
|
||||
self_query = (
|
||||
select(BestScore)
|
||||
.where(BestScore.user_id == user.id)
|
||||
.where(
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
)
|
||||
.order_by(col(BestScore.total_score).desc())
|
||||
.limit(1)
|
||||
)
|
||||
@@ -616,6 +632,8 @@ async def process_score(
|
||||
fetcher: "Fetcher",
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
) -> Score:
|
||||
assert user.id
|
||||
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||
@@ -647,6 +665,8 @@ async def process_score(
|
||||
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
|
||||
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
|
||||
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
|
||||
playlist_item_id=item_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
if can_get_pp:
|
||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
@@ -678,4 +698,5 @@ async def process_score(
|
||||
await session.refresh(score)
|
||||
await session.refresh(score_token)
|
||||
await session.refresh(user)
|
||||
await redis.publish("score:processed", score.id)
|
||||
return score
|
||||
|
||||
@@ -38,3 +38,7 @@ async def create_tables():
|
||||
# Redis 依赖
|
||||
def get_redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
def get_redis_pubsub():
|
||||
return redis_client.pubsub()
|
||||
|
||||
26
app/dependencies/scheduler.py
Normal file
26
app/dependencies/scheduler.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
scheduler: AsyncIOScheduler | None = None
|
||||
|
||||
|
||||
def init_scheduler():
|
||||
global scheduler
|
||||
scheduler = AsyncIOScheduler(timezone=UTC)
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def get_scheduler() -> AsyncIOScheduler:
|
||||
global scheduler
|
||||
if scheduler is None:
|
||||
init_scheduler()
|
||||
return scheduler # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
def stop_scheduler():
|
||||
global scheduler
|
||||
if scheduler:
|
||||
scheduler.shutdown()
|
||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
|
||||
from app.models.signalr import SignalRUnionMessage, UserState
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS = 13
|
||||
|
||||
|
||||
class _UserActivity(SignalRUnionMessage): ...
|
||||
|
||||
@@ -96,16 +98,14 @@ UserActivity = (
|
||||
| ModdingBeatmap
|
||||
| TestingBeatmap
|
||||
| InDailyChallengeLobby
|
||||
| PlayingDailyChallenge
|
||||
)
|
||||
|
||||
|
||||
class UserPresence(BaseModel):
|
||||
activity: UserActivity | None = Field(
|
||||
default=None, metadata=SignalRMeta(use_upper_case=True)
|
||||
)
|
||||
status: OnlineStatus | None = Field(
|
||||
default=None, metadata=SignalRMeta(use_upper_case=True)
|
||||
)
|
||||
activity: UserActivity | None = None
|
||||
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
@property
|
||||
def pushable(self) -> bool:
|
||||
@@ -126,3 +126,34 @@ class OnlineStatus(IntEnum):
|
||||
OFFLINE = 0 # 隐身
|
||||
DO_NOT_DISTURB = 1
|
||||
ONLINE = 2
|
||||
|
||||
|
||||
class DailyChallengeInfo(BaseModel):
|
||||
room_id: int
|
||||
|
||||
|
||||
class MultiplayerPlaylistItemStats(BaseModel):
|
||||
playlist_item_id: int = 0
|
||||
total_score_distribution: list[int] = Field(
|
||||
default_factory=list,
|
||||
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
)
|
||||
cumulative_score: int = 0
|
||||
last_processed_score_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomStats(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerRoomScoreSetEvent(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_id: int
|
||||
score_id: int
|
||||
user_id: int
|
||||
total_score: int
|
||||
new_rank: int | None = None
|
||||
|
||||
@@ -13,3 +13,10 @@ class UTCBaseModel(BaseModel):
|
||||
v = v.replace(tzinfo=UTC)
|
||||
return v.astimezone(UTC).isoformat()
|
||||
return v
|
||||
|
||||
|
||||
Cursor = dict[str, int]
|
||||
|
||||
|
||||
class RespWithCursor(BaseModel):
|
||||
cursor: Cursor | None = None
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.path import STATIC_DIR
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: NotRequired[dict[str, bool | float | str]]
|
||||
settings: NotRequired[dict[str, bool | float | str | int]]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-api/wiki#mods
|
||||
|
||||
926
app/models/multiplayer_hub.py
Normal file
926
app/models/multiplayer_hub.py
Normal file
@@ -0,0 +1,926 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import IntEnum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
TypedDict,
|
||||
cast,
|
||||
override,
|
||||
)
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.dependencies.database import engine
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
|
||||
from .mods import APIMod
|
||||
from .room import (
|
||||
DownloadState,
|
||||
MatchType,
|
||||
MultiplayerRoomState,
|
||||
MultiplayerUserState,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomStatus,
|
||||
)
|
||||
from .signalr import (
|
||||
SignalRMeta,
|
||||
SignalRUnionMessage,
|
||||
UserState,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
PER_USER_LIMIT = 3
|
||||
|
||||
|
||||
class MultiplayerClientState(UserState):
|
||||
room_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomSettings(BaseModel):
|
||||
name: str = "Unnamed Room"
|
||||
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
password: str = ""
|
||||
match_type: MatchType = MatchType.HEAD_TO_HEAD
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_start_duration: timedelta = timedelta(seconds=0)
|
||||
auto_skip: bool = False
|
||||
|
||||
@property
|
||||
def auto_start_enabled(self) -> bool:
|
||||
return self.auto_start_duration != timedelta(seconds=0)
|
||||
|
||||
|
||||
class BeatmapAvailability(BaseModel):
|
||||
state: DownloadState = DownloadState.UNKNOWN
|
||||
download_progress: float | None = None
|
||||
|
||||
|
||||
class _MatchUserState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class TeamVersusUserState(_MatchUserState):
|
||||
team_id: int
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchUserState = TeamVersusUserState
|
||||
|
||||
|
||||
class _MatchRoomState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class MultiplayerTeam(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class TeamVersusRoomState(_MatchRoomState):
|
||||
teams: list[MultiplayerTeam] = Field(
|
||||
default_factory=lambda: [
|
||||
MultiplayerTeam(id=0, name="Team Red"),
|
||||
MultiplayerTeam(id=1, name="Team Blue"),
|
||||
]
|
||||
)
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchRoomState = TeamVersusRoomState
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str
|
||||
ruleset_id: int
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool
|
||||
playlist_order: int
|
||||
played_at: datetime | None = None
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _get_api_mods(self):
|
||||
from app.models.mods import API_MODS, init_mods
|
||||
|
||||
if not API_MODS:
|
||||
init_mods()
|
||||
return API_MODS
|
||||
|
||||
def _validate_mod_for_ruleset(
|
||||
self, mod: APIMod, ruleset_key: int, context: str = "mod"
|
||||
) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if (
|
||||
typed_ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[typed_ruleset_key]
|
||||
):
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is invalid for this ruleset"
|
||||
)
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not playable by users"
|
||||
)
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not valid for multiplayer"
|
||||
)
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
|
||||
if mod1_settings:
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(
|
||||
f"Mods {mod1['acronym']} and "
|
||||
f"{mod2['acronym']} are incompatible"
|
||||
)
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
for req_mod in self.required_mods:
|
||||
req_acronym = req_mod["acronym"]
|
||||
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
|
||||
if req_settings:
|
||||
incompatible = set(req_settings.get("IncompatibleMods", []))
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(
|
||||
f"Required mod {req_acronym} conflicts with "
|
||||
f"allowed mods: {conflict_list}"
|
||||
)
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
|
||||
# Validate required mods
|
||||
for mod in self.required_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
|
||||
|
||||
# Validate allowed mods
|
||||
for mod in self.allowed_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
|
||||
|
||||
# Check internal compatibility of required mods
|
||||
self._check_mod_compatibility(self.required_mods, ruleset_key)
|
||||
|
||||
# Check compatibility between required and allowed mods
|
||||
self._check_required_allowed_compatibility(ruleset_key)
|
||||
|
||||
def validate_user_mods(
|
||||
self,
|
||||
user: "MultiplayerRoomUser",
|
||||
proposed_mods: list[APIMod],
|
||||
) -> tuple[bool, list[APIMod]]:
|
||||
"""
|
||||
Validates user mods against playlist item rules and returns valid mods.
|
||||
Returns (is_valid, valid_mods).
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
|
||||
valid_mods = []
|
||||
all_proposed_valid = True
|
||||
|
||||
# Check if mods are valid for the ruleset
|
||||
for mod in proposed_mods:
|
||||
if (
|
||||
ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[ruleset_key]
|
||||
):
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
valid_mods.append(mod)
|
||||
|
||||
# Check mod compatibility within user mods
|
||||
incompatible_mods = set()
|
||||
final_valid_mods = []
|
||||
for mod in valid_mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
|
||||
if setting_mods:
|
||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||
final_valid_mods.append(mod)
|
||||
|
||||
# If not freestyle, check against allowed mods
|
||||
if not self.freestyle:
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
if mod["acronym"] not in allowed_acronyms:
|
||||
all_proposed_valid = False
|
||||
else:
|
||||
filtered_valid_mods.append(mod)
|
||||
final_valid_mods = filtered_valid_mods
|
||||
|
||||
# Check compatibility with required mods
|
||||
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||
all_mod_acronyms = {
|
||||
mod["acronym"] for mod in final_valid_mods
|
||||
} | required_mod_acronyms
|
||||
|
||||
# Check for incompatibility between required and user mods
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
mod_acronym = mod["acronym"]
|
||||
is_compatible = True
|
||||
|
||||
for other_acronym in all_mod_acronyms:
|
||||
if other_acronym == mod_acronym:
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
|
||||
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
|
||||
is_compatible = False
|
||||
all_proposed_valid = False
|
||||
break
|
||||
|
||||
if is_compatible:
|
||||
filtered_valid_mods.append(mod)
|
||||
|
||||
return all_proposed_valid, filtered_valid_mods
|
||||
|
||||
def clone(self) -> "PlaylistItem":
|
||||
copy = self.model_copy()
|
||||
copy.required_mods = list(self.required_mods)
|
||||
copy.allowed_mods = list(self.allowed_mods)
|
||||
copy.expired = False
|
||||
copy.played_at = None
|
||||
return copy
|
||||
|
||||
|
||||
class _MultiplayerCountdown(SignalRUnionMessage):
|
||||
id: int = 0
|
||||
time_remaining: timedelta
|
||||
is_exclusive: Annotated[
|
||||
bool, Field(default=True), SignalRMeta(member_ignore=True)
|
||||
] = True
|
||||
|
||||
|
||||
class MatchStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
|
||||
|
||||
MultiplayerCountdown = (
|
||||
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
user_id: int
|
||||
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
availability: BeatmapAvailability = BeatmapAvailability(
|
||||
state=DownloadState.UNKNOWN, download_progress=None
|
||||
)
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
match_state: MatchUserState | None = None
|
||||
ruleset_id: int | None = None # freestyle
|
||||
beatmap_id: int | None = None # freestyle
|
||||
|
||||
|
||||
class MultiplayerRoom(BaseModel):
|
||||
room_id: int
|
||||
state: MultiplayerRoomState
|
||||
settings: MultiplayerRoomSettings
|
||||
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
||||
host: MultiplayerRoomUser | None = None
|
||||
match_state: MatchRoomState | None = None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room) -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
|
||||
# 用户列表
|
||||
users = [MultiplayerRoomUser(user_id=room.host_id)]
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if hasattr(room, "playlist"):
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
id=item.id,
|
||||
owner_id=item.owner_id,
|
||||
beatmap_id=item.beatmap_id,
|
||||
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
|
||||
ruleset_id=item.ruleset_id,
|
||||
required_mods=item.required_mods,
|
||||
allowed_mods=item.allowed_mods,
|
||||
expired=item.expired,
|
||||
playlist_order=item.playlist_order,
|
||||
played_at=item.played_at,
|
||||
star_rating=item.beatmap.difficulty_rating
|
||||
if item.beatmap is not None
|
||||
else 0.0,
|
||||
freestyle=item.freestyle,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
room_id=room.id,
|
||||
state=getattr(room, "state", MultiplayerRoomState.OPEN),
|
||||
settings=MultiplayerRoomSettings(
|
||||
name=room.name,
|
||||
playlist_item_id=playlist[0].id if playlist else 0,
|
||||
password=getattr(room, "password", ""),
|
||||
match_type=room.type,
|
||||
queue_mode=room.queue_mode,
|
||||
auto_start_duration=timedelta(seconds=room.auto_start_duration),
|
||||
auto_skip=room.auto_skip,
|
||||
),
|
||||
users=users,
|
||||
host=host_user,
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=getattr(room, "channel_id", 0),
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerQueue:
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.server_room = room
|
||||
self.current_index = 0
|
||||
|
||||
@property
|
||||
def hub(self) -> "MultiplayerHub":
|
||||
return self.server_room.hub
|
||||
|
||||
@property
|
||||
def upcoming_items(self):
|
||||
return sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda i: i.playlist_order,
|
||||
)
|
||||
|
||||
@property
|
||||
def room(self):
|
||||
return self.server_room.room
|
||||
|
||||
async def update_order(self):
|
||||
from app.database import Playlist
|
||||
|
||||
match self.room.settings.queue_mode:
|
||||
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
|
||||
ordered_active_items = []
|
||||
|
||||
is_first_set = True
|
||||
first_set_order_by_user_id = {}
|
||||
|
||||
active_items = [item for item in self.room.playlist if not item.expired]
|
||||
active_items.sort(key=lambda x: x.id)
|
||||
|
||||
user_item_groups = {}
|
||||
for item in active_items:
|
||||
if item.owner_id not in user_item_groups:
|
||||
user_item_groups[item.owner_id] = []
|
||||
user_item_groups[item.owner_id].append(item)
|
||||
|
||||
max_items = max(
|
||||
(len(items) for items in user_item_groups.values()), default=0
|
||||
)
|
||||
|
||||
for i in range(max_items):
|
||||
current_set = []
|
||||
for user_id, items in user_item_groups.items():
|
||||
if i < len(items):
|
||||
current_set.append(items[i])
|
||||
|
||||
if is_first_set:
|
||||
current_set.sort(
|
||||
key=lambda item: (item.playlist_order, item.id)
|
||||
)
|
||||
ordered_active_items.extend(current_set)
|
||||
first_set_order_by_user_id = {
|
||||
item.owner_id: idx
|
||||
for idx, item in enumerate(ordered_active_items)
|
||||
}
|
||||
else:
|
||||
current_set.sort(
|
||||
key=lambda item: first_set_order_by_user_id.get(
|
||||
item.owner_id, 0
|
||||
)
|
||||
)
|
||||
ordered_active_items.extend(current_set)
|
||||
|
||||
is_first_set = False
|
||||
case _:
|
||||
ordered_active_items = sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
async with AsyncSession(engine) as session:
|
||||
for idx, item in enumerate(ordered_active_items):
|
||||
if item.playlist_order == idx:
|
||||
continue
|
||||
item.playlist_order = idx
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room, item, beatmap_changed=False
|
||||
)
|
||||
|
||||
async def update_current_item(self):
|
||||
upcoming_items = self.upcoming_items
|
||||
next_item = (
|
||||
upcoming_items[0]
|
||||
if upcoming_items
|
||||
else max(
|
||||
self.room.playlist,
|
||||
key=lambda i: i.played_at or datetime.min,
|
||||
)
|
||||
)
|
||||
self.current_index = self.room.playlist.index(next_item)
|
||||
last_id = self.room.settings.playlist_item_id
|
||||
self.room.settings.playlist_item_id = next_item.id
|
||||
if last_id != next_item.id:
|
||||
await self.hub.setting_changed(self.server_room, True)
|
||||
|
||||
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
is_host = self.room.host and self.room.host.user_id == user.user_id
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
|
||||
raise InvokeException("You are not the host")
|
||||
|
||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||
if (
|
||||
len([True for u in self.room.playlist if u.owner_id == user.user_id])
|
||||
>= limit
|
||||
):
|
||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=item.beatmap_id
|
||||
)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(
|
||||
beatmap.difficulty_rating
|
||||
) # FIXME: beatmap use decimal
|
||||
await Playlist.add_to_db(item, self.room.room_id, session)
|
||||
self.room.playlist.append(item)
|
||||
await self.hub.playlist_added(self.server_room, item)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=item.beatmap_id
|
||||
)
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
existing_item = next(
|
||||
(i for i in self.room.playlist if i.id == item.id), None
|
||||
)
|
||||
if existing_item is None:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item that doesn't exist"
|
||||
)
|
||||
|
||||
if existing_item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which is not owned by the user"
|
||||
)
|
||||
|
||||
if existing_item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which has already been played"
|
||||
)
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(beatmap.difficulty_rating)
|
||||
item.playlist_order = existing_item.playlist_order
|
||||
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
|
||||
# Update item in playlist
|
||||
for idx, playlist_item in enumerate(self.room.playlist):
|
||||
if playlist_item.id == item.id:
|
||||
self.room.playlist[idx] = item
|
||||
break
|
||||
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room,
|
||||
item,
|
||||
beatmap_changed=item.beatmap_checksum
|
||||
!= existing_item.beatmap_checksum,
|
||||
)
|
||||
|
||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
item = next(
|
||||
(i for i in self.room.playlist if i.id == playlist_item_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if item is None:
|
||||
raise InvokeException("Item does not exist in the room")
|
||||
|
||||
# Check if it's the only item and current item
|
||||
if item == self.current_item:
|
||||
upcoming_items = [i for i in self.room.playlist if not i.expired]
|
||||
if len(upcoming_items) == 1:
|
||||
raise InvokeException("The only item in the room cannot be removed")
|
||||
|
||||
if item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which is not owned by the user"
|
||||
)
|
||||
|
||||
if item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which has already been played"
|
||||
)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
|
||||
self.room.playlist.remove(item)
|
||||
self.current_index = self.room.playlist.index(self.upcoming_items[0])
|
||||
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
await self.hub.playlist_removed(self.server_room, item.id)
|
||||
|
||||
async def finish_current_item(self):
|
||||
from app.database import Playlist
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
played_at = datetime.now(UTC)
|
||||
await session.execute(
|
||||
update(Playlist)
|
||||
.where(
|
||||
col(Playlist.id) == self.current_item.id,
|
||||
col(Playlist.room_id) == self.room.room_id,
|
||||
)
|
||||
.values(expired=True, played_at=played_at)
|
||||
)
|
||||
self.room.playlist[self.current_index].expired = True
|
||||
self.room.playlist[self.current_index].played_at = played_at
|
||||
await self.hub.playlist_changed(self.server_room, self.current_item, True)
|
||||
await self.update_order()
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_current_item()
|
||||
|
||||
async def update_queue_mode(self):
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
@property
|
||||
def current_item(self):
|
||||
return self.room.playlist[self.current_index]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownInfo:
|
||||
countdown: MultiplayerCountdown
|
||||
duration: timedelta
|
||||
task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, countdown: MultiplayerCountdown):
|
||||
self.countdown = countdown
|
||||
self.duration = (
|
||||
countdown.time_remaining
|
||||
if countdown.time_remaining > timedelta(seconds=0)
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
class _MatchRequest(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChangeTeamRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
team_id: int
|
||||
|
||||
|
||||
class StartMatchCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
duration: timedelta
|
||||
|
||||
|
||||
class StopCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
id: int
|
||||
|
||||
|
||||
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
|
||||
|
||||
|
||||
class MatchTypeHandler(ABC):
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.room = room
|
||||
self.hub = room.hub
|
||||
|
||||
@abstractmethod
|
||||
async def handle_join(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_details(self) -> MatchStartedEventDetail: ...
|
||||
|
||||
|
||||
class HeadToHeadHandler(MatchTypeHandler):
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
if user.match_state is not None:
|
||||
user.match_state = None
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
|
||||
return detail
|
||||
|
||||
|
||||
class TeamVersusHandler(MatchTypeHandler):
|
||||
@override
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
super().__init__(room)
|
||||
self.state = TeamVersusRoomState()
|
||||
room.room.match_state = self.state
|
||||
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
|
||||
self.hub.tasks.add(task)
|
||||
task.add_done_callback(self.hub.tasks.discard)
|
||||
|
||||
def _get_best_available_team(self) -> int:
|
||||
for team in self.state.teams:
|
||||
if all(
|
||||
(
|
||||
user.match_state is None
|
||||
or not isinstance(user.match_state, TeamVersusUserState)
|
||||
or user.match_state.team_id != team.id
|
||||
)
|
||||
for user in self.room.room.users
|
||||
):
|
||||
return team.id
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
team_counts = defaultdict(int)
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
team_counts[user.match_state.team_id] += 1
|
||||
|
||||
if team_counts:
|
||||
min_count = min(team_counts.values())
|
||||
for team_id, count in team_counts.items():
|
||||
if count == min_count:
|
||||
return team_id
|
||||
return self.state.teams[0].id if self.state.teams else 0
|
||||
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
best_team_id = self._get_best_available_team()
|
||||
user.match_state = TeamVersusUserState(team_id=best_team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
|
||||
if not isinstance(request, ChangeTeamRequest):
|
||||
return
|
||||
|
||||
if request.team_id not in [team.id for team in self.state.teams]:
|
||||
raise InvokeException("Invalid team ID")
|
||||
|
||||
user.match_state = TeamVersusUserState(team_id=request.team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
teams: dict[int, Literal["blue", "red"]] = {}
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
|
||||
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
|
||||
return detail
|
||||
|
||||
|
||||
MATCH_TYPE_HANDLERS = {
|
||||
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
|
||||
MatchType.TEAM_VERSUS: TeamVersusHandler,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMultiplayerRoom:
|
||||
room: MultiplayerRoom
|
||||
category: RoomCategory
|
||||
status: RoomStatus
|
||||
start_at: datetime
|
||||
hub: "MultiplayerHub"
|
||||
match_type_handler: MatchTypeHandler
|
||||
queue: MultiplayerQueue
|
||||
_next_countdown_id: int
|
||||
_countdown_id_lock: asyncio.Lock
|
||||
_tracked_countdown: dict[int, CountdownInfo]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room: MultiplayerRoom,
|
||||
category: RoomCategory,
|
||||
start_at: datetime,
|
||||
hub: "MultiplayerHub",
|
||||
):
|
||||
self.room = room
|
||||
self.category = category
|
||||
self.status = RoomStatus.IDLE
|
||||
self.start_at = start_at
|
||||
self.hub = hub
|
||||
self.queue = MultiplayerQueue(self)
|
||||
self._next_countdown_id = 0
|
||||
self._countdown_id_lock = asyncio.Lock()
|
||||
self._tracked_countdown = {}
|
||||
|
||||
async def set_handler(self):
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
|
||||
self
|
||||
)
|
||||
for i in self.room.users:
|
||||
await self.match_type_handler.handle_join(i)
|
||||
|
||||
async def get_next_countdown_id(self) -> int:
|
||||
async with self._countdown_id_lock:
|
||||
self._next_countdown_id += 1
|
||||
return self._next_countdown_id
|
||||
|
||||
async def start_countdown(
|
||||
self,
|
||||
countdown: MultiplayerCountdown,
|
||||
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
|
||||
):
|
||||
async def _countdown_task(self: "ServerMultiplayerRoom"):
|
||||
await asyncio.sleep(info.duration.total_seconds())
|
||||
if on_complete is not None:
|
||||
await on_complete(self)
|
||||
await self.stop_countdown(countdown)
|
||||
|
||||
if countdown.is_exclusive:
|
||||
await self.stop_all_countdowns(countdown.__class__)
|
||||
countdown.id = await self.get_next_countdown_id()
|
||||
info = CountdownInfo(countdown)
|
||||
self.room.active_countdowns.append(info.countdown)
|
||||
self._tracked_countdown[countdown.id] = info
|
||||
await self.hub.send_match_event(
|
||||
self, CountdownStartedEvent(countdown=info.countdown)
|
||||
)
|
||||
info.task = asyncio.create_task(_countdown_task(self))
|
||||
|
||||
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||
info = self._tracked_countdown.get(countdown.id)
|
||||
if info is None:
|
||||
return
|
||||
del self._tracked_countdown[countdown.id]
|
||||
self.room.active_countdowns.remove(countdown)
|
||||
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
|
||||
if info.task is not None and not info.task.done():
|
||||
info.task.cancel()
|
||||
|
||||
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
|
||||
for countdown in list(self._tracked_countdown.values()):
|
||||
if isinstance(countdown.countdown, typ):
|
||||
await self.stop_countdown(countdown.countdown)
|
||||
|
||||
|
||||
class _MatchServerEvent(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class CountdownStartedEvent(_MatchServerEvent):
|
||||
countdown: MultiplayerCountdown
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class CountdownStoppedEvent(_MatchServerEvent):
|
||||
id: int
|
||||
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
|
||||
|
||||
|
||||
class GameplayAbortReason(IntEnum):
|
||||
LOAD_TOOK_TOO_LONG = 0
|
||||
HOST_ABORTED = 1
|
||||
|
||||
|
||||
class MatchStartedEventDetail(TypedDict):
|
||||
room_type: Literal["playlists", "head_to_head", "team_versus"]
|
||||
team: dict[int, Literal["blue", "red"]] | None
|
||||
@@ -1,7 +1,6 @@
|
||||
# OAuth 相关模型
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
"""注册错误响应模型"""
|
||||
|
||||
form_error: dict
|
||||
|
||||
|
||||
class UserRegistrationErrors(BaseModel):
|
||||
"""用户注册错误模型"""
|
||||
username: List[str] = []
|
||||
user_email: List[str] = []
|
||||
password: List[str] = []
|
||||
|
||||
username: list[str] = []
|
||||
user_email: list[str] = []
|
||||
password: list[str] = []
|
||||
|
||||
|
||||
class RegistrationRequestErrors(BaseModel):
|
||||
"""注册请求错误模型"""
|
||||
|
||||
message: str | None = None
|
||||
redirect: str | None = None
|
||||
user: UserRegistrationErrors | None = None
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.database import User
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from .model import UTCBaseModel
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RoomCategory(str, Enum):
|
||||
@@ -17,6 +10,7 @@ class RoomCategory(str, Enum):
|
||||
SPOTLIGHT = "spotlight"
|
||||
FEATURED_ARTIST = "featured_artist"
|
||||
DAILY_CHALLENGE = "daily_challenge"
|
||||
REALTIME = "realtime" # INTERNAL USE ONLY, DO NOT USE IN API
|
||||
|
||||
|
||||
class MatchType(str, Enum):
|
||||
@@ -42,18 +36,40 @@ class RoomStatus(str, Enum):
|
||||
PLAYING = "playing"
|
||||
|
||||
|
||||
class PlaylistItem(UTCBaseModel):
|
||||
id: int | None
|
||||
owner_id: int
|
||||
ruleset_id: int
|
||||
expired: bool
|
||||
playlist_order: int | None
|
||||
played_at: datetime | None
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
beatmap_id: int
|
||||
beatmap: Beatmap | None
|
||||
freestyle: bool
|
||||
class MultiplayerRoomState(str, Enum):
|
||||
OPEN = "open"
|
||||
WAITING_FOR_LOAD = "waiting_for_load"
|
||||
PLAYING = "playing"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class MultiplayerUserState(str, Enum):
|
||||
IDLE = "idle"
|
||||
READY = "ready"
|
||||
WAITING_FOR_LOAD = "waiting_for_load"
|
||||
LOADED = "loaded"
|
||||
READY_FOR_GAMEPLAY = "ready_for_gameplay"
|
||||
PLAYING = "playing"
|
||||
FINISHED_PLAY = "finished_play"
|
||||
RESULTS = "results"
|
||||
SPECTATING = "spectating"
|
||||
|
||||
@property
|
||||
def is_playing(self) -> bool:
|
||||
return self in {
|
||||
self.WAITING_FOR_LOAD,
|
||||
self.PLAYING,
|
||||
self.READY_FOR_GAMEPLAY,
|
||||
self.LOADED,
|
||||
}
|
||||
|
||||
|
||||
class DownloadState(str, Enum):
|
||||
UNKNOWN = "unknown"
|
||||
NOT_DOWNLOADED = "not_downloaded"
|
||||
DOWNLOADING = "downloading"
|
||||
IMPORTING = "importing"
|
||||
LOCALLY_AVAILABLE = "locally_available"
|
||||
|
||||
|
||||
class RoomPlaylistItemStats(BaseModel):
|
||||
@@ -67,39 +83,7 @@ class RoomDifficultyRange(BaseModel):
|
||||
max: float
|
||||
|
||||
|
||||
class ItemAttemptsCount(BaseModel):
|
||||
id: int
|
||||
attempts: int
|
||||
passed: bool
|
||||
|
||||
|
||||
class PlaylistAggregateScore(BaseModel):
|
||||
playlist_item_attempts: list[ItemAttemptsCount]
|
||||
|
||||
|
||||
class Room(UTCBaseModel):
|
||||
id: int | None
|
||||
name: str = ""
|
||||
password: str | None
|
||||
has_password: bool = False
|
||||
host: User | None
|
||||
category: RoomCategory = RoomCategory.NORMAL
|
||||
duration: int | None
|
||||
starts_at: datetime | None
|
||||
ends_at: datetime | None
|
||||
participant_count: int = 0
|
||||
recent_participants: list[User] = Field(default_factory=list)
|
||||
max_attempts: int | None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
playlist_item_stats: RoomPlaylistItemStats | None
|
||||
difficulty_range: RoomDifficultyRange | None
|
||||
type: MatchType = MatchType.PLAYLISTS
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_skip: bool = False
|
||||
auto_start_duration: int = 0
|
||||
current_user_score: PlaylistAggregateScore | None
|
||||
current_playlist_item: PlaylistItem | None
|
||||
channel_id: int = 0
|
||||
status: RoomStatus = RoomStatus.IDLE
|
||||
# availability 字段在当前序列化中未包含,但可能在某些场景下需要
|
||||
availability: RoomAvailability | None
|
||||
class PlaylistStatus(BaseModel):
|
||||
count_active: int
|
||||
count_total: int
|
||||
ruleset_ids: list[int]
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
)
|
||||
|
||||
@@ -15,23 +13,7 @@ from pydantic import (
|
||||
class SignalRMeta:
|
||||
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||
use_upper_case: bool = False # use upper CamelCase for field names
|
||||
|
||||
|
||||
def _by_index(v: Any, class_: type[Enum]):
|
||||
enum_list = list(class_)
|
||||
if not isinstance(v, int):
|
||||
return v
|
||||
if 0 <= v < len(enum_list):
|
||||
return enum_list[v]
|
||||
raise ValueError(
|
||||
f"Value {v} is out of range for enum "
|
||||
f"{class_.__name__} with {len(enum_list)} items"
|
||||
)
|
||||
|
||||
|
||||
def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator:
|
||||
return BeforeValidator(lambda v: _by_index(v, enum_class))
|
||||
use_abbr: bool = True
|
||||
|
||||
|
||||
class SignalRUnionMessage(BaseModel):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
@@ -89,9 +89,9 @@ class LegacyReplayFrame(BaseModel):
|
||||
mouse_y: float | None = None
|
||||
button_state: int
|
||||
|
||||
header: FrameHeader | None = Field(
|
||||
default=None, metadata=[SignalRMeta(member_ignore=True)]
|
||||
)
|
||||
header: Annotated[
|
||||
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
|
||||
]
|
||||
|
||||
|
||||
class FrameDataBundle(BaseModel):
|
||||
|
||||
@@ -7,6 +7,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
room,
|
||||
score,
|
||||
user,
|
||||
)
|
||||
@@ -14,4 +15,9 @@ from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
|
||||
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
|
||||
__all__ = [
|
||||
"api_router",
|
||||
"auth_router",
|
||||
"fetcher_router",
|
||||
"signalr_router",
|
||||
]
|
||||
|
||||
@@ -74,9 +74,10 @@ class BatchGetResp(BaseModel):
|
||||
@router.get("/beatmaps", tags=["beatmap"], response_model=BatchGetResp)
|
||||
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
b_ids: list[int] = Query(alias="ids[]", default_factory=list),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
@@ -86,9 +87,27 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = (
|
||||
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
|
||||
).all()
|
||||
beatmaps = list(
|
||||
(
|
||||
await db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
not_found_beatmaps = [
|
||||
bid for bid in b_ids if bid not in [bm.id for bm in beatmaps]
|
||||
]
|
||||
beatmaps.extend(
|
||||
beatmap
|
||||
for beatmap in await asyncio.gather(
|
||||
*[
|
||||
Beatmap.get_or_fetch(db, fetcher, bid=bid)
|
||||
for bid in not_found_beatmaps
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(beatmap, Beatmap)
|
||||
)
|
||||
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -12,11 +12,25 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from httpx import HTTPStatusError
|
||||
from httpx import HTTPError
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def lookup_beatmapset(
|
||||
beatmap_id: int = Query(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=db, user=current_user
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
@@ -24,18 +38,13 @@ async def get_beatmapset(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
|
||||
if not beatmapset:
|
||||
try:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
await Beatmapset.from_resp(db, resp)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
else:
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
try:
|
||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, sid)
|
||||
return await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return resp
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
|
||||
|
||||
@@ -96,9 +96,7 @@ async def add_relationship(
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
return AddFriendResp(
|
||||
user_relation=await RelationshipResp.from_db(db, relationship)
|
||||
)
|
||||
return await RelationshipResp.from_db(db, relationship)
|
||||
|
||||
|
||||
@router.delete("/friends/{target}", tags=["relationship"])
|
||||
|
||||
@@ -1,33 +1,346 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.room import RoomIndex
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
|
||||
from app.database.beatmap import Beatmap, BeatmapResp
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from app.database.playlist_attempts import ItemAttemptsCount, ItemAttemptsResp
|
||||
from app.database.playlists import Playlist, PlaylistResp
|
||||
from app.database.room import APIUploadedRoom, Room, RoomResp
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.models.room import Room
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.room import RoomCategory, RoomStatus
|
||||
from app.service.room import create_playlist_room_from_api
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/rooms", tags=["rooms"], response_model=list[Room])
|
||||
@router.get("/rooms", tags=["rooms"], response_model=list[RoomResp])
|
||||
async def get_all_rooms(
|
||||
mode: str = Query(
|
||||
None
|
||||
), # TODO: lazer源码显示房间不会是除了open以外的其他状态,先放在这里
|
||||
status: str = Query(None),
|
||||
category: str = Query(None),
|
||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||
default="open"
|
||||
),
|
||||
category: RoomCategory = Query(RoomCategory.NORMAL),
|
||||
status: RoomStatus | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
resp_list: list[RoomResp] = []
|
||||
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
|
||||
now = datetime.now(UTC)
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
if mode == "open":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_(None))
|
||||
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
|
||||
)
|
||||
if category == RoomCategory.REALTIME:
|
||||
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
if mode == "participated":
|
||||
where_clauses.append(
|
||||
exists().where(
|
||||
col(RoomParticipatedUser.room_id) == Room.id,
|
||||
col(RoomParticipatedUser.user_id) == current_user.id,
|
||||
)
|
||||
)
|
||||
if mode == "owned":
|
||||
where_clauses.append(col(Room.host_id) == current_user.id)
|
||||
if mode == "ended":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_not(None))
|
||||
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
|
||||
)
|
||||
|
||||
db_rooms = (
|
||||
(
|
||||
await db.exec(
|
||||
select(Room).where(
|
||||
*where_clauses,
|
||||
)
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
|
||||
for room in db_rooms:
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
if category == RoomCategory.REALTIME:
|
||||
resp.has_password = bool(
|
||||
MultiplayerHubs.rooms[room.id].room.settings.password.strip()
|
||||
)
|
||||
resp.category = RoomCategory.NORMAL
|
||||
resp_list.append(resp)
|
||||
|
||||
return resp_list
|
||||
|
||||
|
||||
class APICreatedRoom(RoomResp):
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession
|
||||
):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
RoomParticipatedUser.room_id == room_id,
|
||||
RoomParticipatedUser.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if participated_user is None:
|
||||
participated_user = RoomParticipatedUser(
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
joined_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(participated_user)
|
||||
else:
|
||||
participated_user.left_at = None
|
||||
participated_user.joined_at = datetime.now(UTC)
|
||||
db_room.participant_count += 1
|
||||
|
||||
|
||||
@router.post("/rooms", tags=["room"], response_model=APICreatedRoom)
|
||||
async def create_room(
|
||||
room: APIUploadedRoom,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db)
|
||||
# await db.commit()
|
||||
# await db.refresh(db_room)
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
||||
created_room.error = ""
|
||||
return created_room
|
||||
|
||||
|
||||
@router.get("/rooms/{room}", tags=["room"], response_model=RoomResp)
|
||||
async def get_room(
|
||||
room: int,
|
||||
category: str = Query(default=""),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
all_room_ids = (await db.exec(select(RoomIndex).where(True))).all()
|
||||
roomsList: list[Room] = []
|
||||
for room_index in all_room_ids:
|
||||
dumped_room = await redis.get(str(room_index.id))
|
||||
if dumped_room:
|
||||
actual_room = Room.model_validate_json(str(dumped_room))
|
||||
if actual_room.status == status and actual_room.category == category:
|
||||
roomsList.append(actual_room)
|
||||
return roomsList
|
||||
# 直接从db获取信息,毕竟都一样
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
resp = await RoomResp.from_db(
|
||||
db_room, include=["current_user_score"], session=db, user=current_user
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.delete("/rooms/{room}", tags=["room"])
|
||||
async def delete_room(room: int, db: AsyncSession = Depends(get_db)):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
else:
|
||||
db_room.ends_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
|
||||
@router.put("/rooms/{room}/users/{user}", tags=["room"])
|
||||
async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is not None:
|
||||
await _participate_room(room, user, db_room, db)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
resp = await RoomResp.from_db(db_room, db)
|
||||
|
||||
return resp
|
||||
else:
|
||||
raise HTTPException(404, "room not found0")
|
||||
|
||||
|
||||
@router.delete("/rooms/{room}/users/{user}", tags=["room"])
|
||||
async def remove_user_from_room(
|
||||
room: int, user: int, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is not None:
|
||||
participated_user = (
|
||||
await db.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
RoomParticipatedUser.room_id == room,
|
||||
RoomParticipatedUser.user_id == user,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if participated_user is not None:
|
||||
participated_user.left_at = datetime.now(UTC)
|
||||
db_room.participant_count -= 1
|
||||
await db.commit()
|
||||
return None
|
||||
else:
|
||||
raise HTTPException(404, "Room not found")
|
||||
|
||||
|
||||
class APILeaderboard(BaseModel):
|
||||
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
|
||||
user_score: ItemAttemptsResp | None = None
|
||||
|
||||
|
||||
@router.get("/rooms/{room}/leaderboard", tags=["room"], response_model=APILeaderboard)
|
||||
async def get_room_leaderboard(
|
||||
room: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
|
||||
aggs = await db.exec(
|
||||
select(ItemAttemptsCount)
|
||||
.where(ItemAttemptsCount.room_id == room)
|
||||
.order_by(col(ItemAttemptsCount.total_score).desc())
|
||||
)
|
||||
aggs_resp = []
|
||||
user_agg = None
|
||||
for i, agg in enumerate(aggs):
|
||||
resp = await ItemAttemptsResp.from_db(agg, db)
|
||||
resp.position = i + 1
|
||||
# resp.accuracy *= 100
|
||||
aggs_resp.append(resp)
|
||||
if agg.user_id == current_user.id:
|
||||
user_agg = resp
|
||||
return APILeaderboard(
|
||||
leaderboard=aggs_resp,
|
||||
user_score=user_agg,
|
||||
)
|
||||
|
||||
|
||||
class RoomEvents(BaseModel):
|
||||
beatmaps: list[BeatmapResp] = Field(default_factory=list)
|
||||
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
|
||||
current_playlist_item_id: int = 0
|
||||
events: list[MultiplayerEventResp] = Field(default_factory=list)
|
||||
first_event_id: int = 0
|
||||
last_event_id: int = 0
|
||||
playlist_items: list[PlaylistResp] = Field(default_factory=list)
|
||||
room: RoomResp
|
||||
user: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/events", response_model=RoomEvents, tags=["room"])
|
||||
async def get_room_events(
|
||||
room_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
after: int | None = Query(None, ge=0),
|
||||
before: int | None = Query(None, ge=0),
|
||||
):
|
||||
events = (
|
||||
await db.exec(
|
||||
select(MultiplayerEvent)
|
||||
.where(
|
||||
MultiplayerEvent.room_id == room_id,
|
||||
col(MultiplayerEvent.id) > after if after is not None else True,
|
||||
col(MultiplayerEvent.id) < before if before is not None else True,
|
||||
)
|
||||
.order_by(col(MultiplayerEvent.id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_ids = set()
|
||||
playlist_items = {}
|
||||
beatmap_ids = set()
|
||||
|
||||
event_resps = []
|
||||
first_event_id = 0
|
||||
last_event_id = 0
|
||||
|
||||
current_playlist_item_id = 0
|
||||
for event in events:
|
||||
event_resps.append(MultiplayerEventResp.from_db(event))
|
||||
|
||||
if event.user_id:
|
||||
user_ids.add(event.user_id)
|
||||
|
||||
if event.playlist_item_id is not None and (
|
||||
playitem := (
|
||||
await db.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == event.playlist_item_id,
|
||||
Playlist.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
):
|
||||
current_playlist_item_id = playitem.id
|
||||
playlist_items[event.playlist_item_id] = playitem
|
||||
beatmap_ids.add(playitem.beatmap_id)
|
||||
scores = await db.exec(
|
||||
select(Score).where(
|
||||
Score.playlist_item_id == event.playlist_item_id,
|
||||
Score.room_id == room_id,
|
||||
)
|
||||
)
|
||||
for score in scores:
|
||||
user_ids.add(score.user_id)
|
||||
beatmap_ids.add(score.beatmap_id)
|
||||
|
||||
assert event.id is not None
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
if room := MultiplayerHubs.rooms.get(room_id):
|
||||
current_playlist_item_id = room.queue.current_item.id
|
||||
room_resp = await RoomResp.from_hub(room)
|
||||
else:
|
||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
room_resp = await RoomResp.from_db(room, db)
|
||||
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [
|
||||
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
|
||||
]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
|
||||
playlist_items_resps = [
|
||||
await PlaylistResp.from_db(item) for item in playlist_items.values()
|
||||
]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
beatmapsets=beatmapset_resps,
|
||||
current_playlist_item_id=current_playlist_item_id,
|
||||
events=event_resps,
|
||||
first_event_id=first_event_id,
|
||||
last_event_id=last_event_id,
|
||||
playlist_items=playlist_items_resps,
|
||||
room=room_resp,
|
||||
user=user_resps,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
|
||||
from app.database.score import get_leaderboard, process_score, process_user
|
||||
from datetime import UTC, datetime
|
||||
import time
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
Playlist,
|
||||
Room,
|
||||
Score,
|
||||
ScoreResp,
|
||||
ScoreToken,
|
||||
ScoreTokenResp,
|
||||
User,
|
||||
)
|
||||
from app.database.playlist_attempts import ItemAttemptsCount
|
||||
from app.database.playlist_best_score import (
|
||||
PlaylistBestScore,
|
||||
get_position,
|
||||
process_playlist_best_score,
|
||||
)
|
||||
from app.database.score import (
|
||||
MultiplayerScores,
|
||||
ScoreAround,
|
||||
get_leaderboard,
|
||||
process_score,
|
||||
process_user,
|
||||
)
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.room import RoomCategory
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
@@ -17,12 +44,78 @@ from app.models.score import (
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from httpx import HTTPError
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
READ_SCORE_TIMEOUT = 10
|
||||
|
||||
|
||||
async def submit_score(
|
||||
info: SoloScoreSubmissionInfo,
|
||||
beatmap: int,
|
||||
token: int,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
fetcher: Fetcher,
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token)
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
try:
|
||||
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
ranked = db_beatmap.beatmap_status in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
score = await process_score(
|
||||
current_user,
|
||||
beatmap,
|
||||
ranked,
|
||||
score_token,
|
||||
info,
|
||||
fetcher,
|
||||
db,
|
||||
redis,
|
||||
item_id,
|
||||
room_id,
|
||||
)
|
||||
await db.refresh(current_user)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
@@ -97,9 +190,10 @@ async def get_user_beatmap_score(
|
||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||
)
|
||||
else:
|
||||
resp = await ScoreResp.from_db(db, user_score)
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=await ScoreResp.from_db(db, user_score),
|
||||
position=resp.rank_global or 0,
|
||||
score=resp,
|
||||
)
|
||||
|
||||
|
||||
@@ -173,55 +267,285 @@ async def submit_solo_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
async with db:
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
|
||||
return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
beatmap_id: int = Form(),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
version_hash: str = Form(""),
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
|
||||
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
|
||||
raise HTTPException(status_code=400, detail="Room has ended")
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||
|
||||
# validate
|
||||
if not item.freestyle:
|
||||
if item.ruleset_id != ruleset_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||
)
|
||||
if item.beatmap_id != beatmap_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||
)
|
||||
agg = await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room_id,
|
||||
ItemAttemptsCount.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
agg = agg.first()
|
||||
if agg and room.max_attempts and agg.attempts >= room.max_attempts:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="You have reached the maximum attempts for this room",
|
||||
)
|
||||
if item.expired:
|
||||
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||
if item.played_at:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Playlist item has already been played"
|
||||
)
|
||||
# 这里应该不用验证mod了吧。。。
|
||||
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap_id,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
playlist_item_id=playlist_id,
|
||||
)
|
||||
session.add(score_token)
|
||||
await session.commit()
|
||||
await session.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
|
||||
async def submit_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||
|
||||
user_id = current_user.id
|
||||
score_resp = await submit_score(
|
||||
info,
|
||||
item.beatmap_id,
|
||||
token,
|
||||
current_user,
|
||||
session,
|
||||
redis,
|
||||
fetcher,
|
||||
item.id,
|
||||
room_id,
|
||||
)
|
||||
await process_playlist_best_score(
|
||||
room_id,
|
||||
playlist_id,
|
||||
user_id,
|
||||
score_resp.id,
|
||||
score_resp.total_score,
|
||||
session,
|
||||
redis,
|
||||
)
|
||||
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
|
||||
return score_resp
|
||||
|
||||
|
||||
class IndexedScoreResp(MultiplayerScores):
|
||||
total: int
|
||||
user_score: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=IndexedScoreResp
|
||||
)
|
||||
async def index_playlist_scores(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = 50,
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
limit = clamp(limit, 1, 50)
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore)
|
||||
.where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.total_score < cursor,
|
||||
)
|
||||
.order_by(col(PlaylistBestScore.total_score).desc())
|
||||
.limit(limit + 1)
|
||||
)
|
||||
).all()
|
||||
has_more = len(scores) > limit
|
||||
if has_more:
|
||||
scores = scores[:-1]
|
||||
|
||||
user_score = None
|
||||
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
|
||||
for score in score_resp:
|
||||
score.position = await get_position(room_id, playlist_id, score.id, session)
|
||||
if score.user_id == current_user.id:
|
||||
user_score = score
|
||||
|
||||
if room.category == RoomCategory.DAILY_CHALLENGE:
|
||||
score_resp = [s for s in score_resp if s.passed]
|
||||
if user_score and not user_score.passed:
|
||||
user_score = None
|
||||
|
||||
resp = IndexedScoreResp(
|
||||
scores=score_resp,
|
||||
user_score=user_score,
|
||||
total=len(scores),
|
||||
params={
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
if has_more:
|
||||
resp.cursor = {
|
||||
"total_score": scores[-1].total_score,
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def show_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
score_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
|
||||
start_time = time.time()
|
||||
score_record = None
|
||||
completed = room.category != RoomCategory.REALTIME
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
if score_record is None:
|
||||
score_record = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.score_id == score_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
beatmap_status = (
|
||||
await db.exec(
|
||||
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
|
||||
if completed_players := await redis.get(
|
||||
f"multiplayer:{room_id}:gameplay:players"
|
||||
):
|
||||
completed = completed_players == "0"
|
||||
if score_record and completed:
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(room_id, playlist_id, score_id, session)
|
||||
if completed:
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
).first()
|
||||
if beatmap_status is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
ranked = beatmap_status in {
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
}
|
||||
score = await process_score(
|
||||
current_user,
|
||||
beatmap,
|
||||
ranked,
|
||||
score_token,
|
||||
info,
|
||||
fetcher,
|
||||
db,
|
||||
redis,
|
||||
)
|
||||
await db.refresh(current_user)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
).all()
|
||||
higher_scores = []
|
||||
lower_scores = []
|
||||
for score in scores:
|
||||
if score.total_score > resp.total_score:
|
||||
higher_scores.append(await ScoreResp.from_db(session, score.score))
|
||||
elif score.total_score < resp.total_score:
|
||||
lower_scores.append(await ScoreResp.from_db(session, score.score))
|
||||
resp.scores_around = ScoreAround(
|
||||
higher=MultiplayerScores(scores=higher_scores),
|
||||
lower=MultiplayerScores(scores=lower_scores),
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def get_user_playlist_score(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
user_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
score_record = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
score_record = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if score_record:
|
||||
break
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(
|
||||
room_id, playlist_id, score_record.score_id, session
|
||||
)
|
||||
return resp
|
||||
|
||||
10
app/service/__init__.py
Normal file
10
app/service/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .daily_challenge import create_daily_challenge_room
|
||||
from .room import create_playlist_room, create_playlist_room_from_api
|
||||
|
||||
__all__ = [
|
||||
"create_daily_challenge_room",
|
||||
"create_playlist_room",
|
||||
"create_playlist_room_from_api",
|
||||
]
|
||||
121
app/service/daily_challenge.py
Normal file
121
app/service/daily_challenge.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.dependencies.scheduler import get_scheduler
|
||||
from app.log import logger
|
||||
from app.models.metadata_hub import DailyChallengeInfo
|
||||
from app.models.mods import APIMod
|
||||
from app.models.room import RoomCategory
|
||||
|
||||
from .room import create_playlist_room
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_daily_challenge_room(
|
||||
beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = []
|
||||
) -> Room:
|
||||
async with AsyncSession(engine) as session:
|
||||
today = datetime.now(UTC).date()
|
||||
return await create_playlist_room(
|
||||
session=session,
|
||||
name=str(today),
|
||||
host_id=3,
|
||||
playlist=[
|
||||
Playlist(
|
||||
id=0,
|
||||
room_id=0,
|
||||
owner_id=3,
|
||||
ruleset_id=ruleset_id,
|
||||
beatmap_id=beatmap,
|
||||
required_mods=required_mods,
|
||||
)
|
||||
],
|
||||
category=RoomCategory.DAILY_CHALLENGE,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="daily_challenge")
|
||||
async def daily_challenge_job():
|
||||
from app.signalr.hub import MetadataHubs
|
||||
|
||||
now = datetime.now(UTC)
|
||||
redis = get_redis()
|
||||
key = f"daily_challenge:{now.date()}"
|
||||
if not await redis.exists(key):
|
||||
return
|
||||
async with AsyncSession(engine) as session:
|
||||
room = (
|
||||
await session.exec(
|
||||
select(Room).where(
|
||||
Room.category == RoomCategory.DAILY_CHALLENGE,
|
||||
col(Room.ends_at) > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if room:
|
||||
return
|
||||
|
||||
try:
|
||||
beatmap = await redis.hget(key, "beatmap") # pyright: ignore[reportGeneralTypeIssues]
|
||||
ruleset_id = await redis.hget(key, "ruleset_id") # pyright: ignore[reportGeneralTypeIssues]
|
||||
required_mods = await redis.hget(key, "required_mods") # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
if beatmap is None or ruleset_id is None:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Missing required data for daily challenge {now}."
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
"date",
|
||||
run_date=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
return
|
||||
|
||||
beatmap_int = int(beatmap)
|
||||
ruleset_id_int = int(ruleset_id)
|
||||
|
||||
mods_list = []
|
||||
if required_mods:
|
||||
mods_list = json.loads(required_mods)
|
||||
|
||||
next_day = (now + timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
room = await create_daily_challenge_room(
|
||||
beatmap=beatmap_int,
|
||||
ruleset_id=ruleset_id_int,
|
||||
required_mods=mods_list,
|
||||
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
|
||||
)
|
||||
await MetadataHubs.broadcast_call(
|
||||
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
|
||||
)
|
||||
logger.success(
|
||||
"[DailyChallenge] Added today's daily challenge: "
|
||||
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
|
||||
)
|
||||
return
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Error processing daily challenge data: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
"date",
|
||||
run_date=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
78
app/service/room.py
Normal file
78
app/service/room.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import APIUploadedRoom, Room
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
|
||||
|
||||
from sqlalchemy import exists
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_playlist_room_from_api(
|
||||
session: AsyncSession, room: APIUploadedRoom, host_id: int
|
||||
) -> Room:
|
||||
db_room = room.to_room()
|
||||
db_room.host_id = host_id
|
||||
db_room.starts_at = datetime.now(UTC)
|
||||
db_room.ends_at = db_room.starts_at + timedelta(
|
||||
minutes=db_room.duration if db_room.duration is not None else 0
|
||||
)
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
|
||||
|
||||
async def create_playlist_room(
|
||||
session: AsyncSession,
|
||||
name: str,
|
||||
host_id: int,
|
||||
category: RoomCategory = RoomCategory.NORMAL,
|
||||
duration: int = 30,
|
||||
max_attempts: int | None = None,
|
||||
playlist: list[Playlist] = [],
|
||||
) -> Room:
|
||||
db_room = Room(
|
||||
name=name,
|
||||
category=category,
|
||||
duration=duration,
|
||||
starts_at=datetime.now(UTC),
|
||||
ends_at=datetime.now(UTC) + timedelta(minutes=duration),
|
||||
participant_count=0,
|
||||
max_attempts=max_attempts,
|
||||
type=MatchType.PLAYLISTS,
|
||||
queue_mode=QueueMode.HOST_ONLY,
|
||||
auto_skip=False,
|
||||
auto_start_duration=0,
|
||||
status=RoomStatus.IDLE,
|
||||
host_id=host_id,
|
||||
)
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
await add_playlists_to_room(session, db_room.id, playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
|
||||
|
||||
async def add_playlists_to_room(
|
||||
session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int
|
||||
):
|
||||
for item in playlist:
|
||||
if not (
|
||||
await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))
|
||||
).first():
|
||||
fetcher = await get_fetcher()
|
||||
await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id)
|
||||
item.id = await Playlist.get_next_id_for_room(room_id, session)
|
||||
item.room_id = room_id
|
||||
item.owner_id = owner_id
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
48
app/service/subscribers/base.py
Normal file
48
app/service/subscribers/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from app.dependencies.database import get_redis_pubsub
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
def __init__(self):
|
||||
self.pubsub = get_redis_pubsub()
|
||||
self.handlers: dict[str, list[Callable[[str, str], Awaitable[Any]]]] = {}
|
||||
self.task: asyncio.Task | None = None
|
||||
|
||||
async def subscribe(self, channel: str):
|
||||
await self.pubsub.subscribe(channel)
|
||||
if channel not in self.handlers:
|
||||
self.handlers[channel] = []
|
||||
|
||||
async def unsubscribe(self, channel: str):
|
||||
if channel in self.handlers:
|
||||
del self.handlers[channel]
|
||||
await self.pubsub.unsubscribe(channel)
|
||||
|
||||
async def listen(self):
|
||||
while True:
|
||||
message = await self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=None
|
||||
)
|
||||
if message is not None and message["type"] == "message":
|
||||
method = self.handlers.get(message["channel"])
|
||||
if method:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
handler(message["channel"], message["data"])
|
||||
for handler in method
|
||||
]
|
||||
)
|
||||
|
||||
def start(self):
|
||||
if self.task is None or self.task.done():
|
||||
self.task = asyncio.create_task(self.listen())
|
||||
|
||||
def stop(self):
|
||||
if self.task is not None and not self.task.done():
|
||||
self.task.cancel()
|
||||
self.task = None
|
||||
87
app/service/subscribers/score_processed.py
Normal file
87
app/service/subscribers/score_processed.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.database import PlaylistBestScore, Score
|
||||
from app.database.playlist_best_score import get_position
|
||||
from app.dependencies.database import engine
|
||||
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
|
||||
|
||||
from .base import RedisSubscriber
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.signalr.hub import MetadataHub
|
||||
|
||||
|
||||
CHANNEL = "score:processed"
|
||||
|
||||
|
||||
class ScoreSubscriber(RedisSubscriber):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.room_subscriber: dict[int, list[int]] = {}
|
||||
self.metadata_hub: "MetadataHub | None " = None
|
||||
self.subscribed = False
|
||||
self.handlers[CHANNEL] = [self._handler]
|
||||
|
||||
async def subscribe_room_score(self, room_id: int, user_id: int):
|
||||
if room_id not in self.room_subscriber:
|
||||
await self.subscribe(CHANNEL)
|
||||
self.start()
|
||||
self.room_subscriber.setdefault(room_id, []).append(user_id)
|
||||
|
||||
async def unsubscribe_room_score(self, room_id: int, user_id: int):
|
||||
if room_id in self.room_subscriber:
|
||||
self.room_subscriber[room_id].remove(user_id)
|
||||
if not self.room_subscriber[room_id]:
|
||||
del self.room_subscriber[room_id]
|
||||
|
||||
async def _notify_room_score_processed(self, score_id: int):
|
||||
if not self.metadata_hub:
|
||||
return
|
||||
async with AsyncSession(engine) as session:
|
||||
score = await session.get(Score, score_id)
|
||||
if (
|
||||
not score
|
||||
or not score.passed
|
||||
or score.room_id is None
|
||||
or score.playlist_item_id is None
|
||||
):
|
||||
return
|
||||
if not self.room_subscriber.get(score.room_id, []):
|
||||
return
|
||||
|
||||
new_rank = None
|
||||
user_best = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.user_id == score.user_id,
|
||||
PlaylistBestScore.room_id == score.room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if user_best and user_best.score_id == score_id:
|
||||
new_rank = await get_position(
|
||||
user_best.room_id,
|
||||
user_best.playlist_id,
|
||||
user_best.score_id,
|
||||
session,
|
||||
)
|
||||
|
||||
event = MultiplayerRoomScoreSetEvent(
|
||||
room_id=score.room_id,
|
||||
playlist_item_id=score.playlist_item_id,
|
||||
score_id=score_id,
|
||||
user_id=score.user_id,
|
||||
total_score=score.total_score,
|
||||
new_rank=new_rank,
|
||||
)
|
||||
await self.metadata_hub.notify_room_score_processed(event)
|
||||
|
||||
async def _handler(self, channel: str, data: str):
|
||||
score_id = int(data)
|
||||
if self.metadata_hub:
|
||||
await self._notify_room_score_processed(score_id)
|
||||
@@ -6,9 +6,9 @@ import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
|
||||
@@ -1,18 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Coroutine
|
||||
from datetime import UTC, datetime
|
||||
import math
|
||||
from typing import override
|
||||
|
||||
from app.database import Relationship, RelationshipType
|
||||
from app.database.lazer_user import User
|
||||
from app.calculator import clamp
|
||||
from app.database import Relationship, RelationshipType, User
|
||||
from app.database.playlist_best_score import PlaylistBestScore
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.dependencies.database import engine, get_redis
|
||||
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
|
||||
from app.models.metadata_hub import (
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
DailyChallengeInfo,
|
||||
MetadataClientState,
|
||||
MultiplayerPlaylistItemStats,
|
||||
MultiplayerRoomScoreSetEvent,
|
||||
MultiplayerRoomStats,
|
||||
OnlineStatus,
|
||||
UserActivity,
|
||||
)
|
||||
from app.models.room import RoomCategory
|
||||
from app.service.subscribers.score_processed import ScoreSubscriber
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
@@ -21,11 +37,33 @@ ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
class MetadataHub(Hub[MetadataClientState]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.subscriber = ScoreSubscriber()
|
||||
self.subscriber.metadata_hub = self
|
||||
self._daily_challenge_stats: MultiplayerRoomStats | None = None
|
||||
self._today = datetime.now(UTC).date()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def get_daily_challenge_stats(
|
||||
self, daily_challenge_room: int
|
||||
) -> MultiplayerRoomStats:
|
||||
if (
|
||||
self._daily_challenge_stats is None
|
||||
or self._today != datetime.now(UTC).date()
|
||||
):
|
||||
self._daily_challenge_stats = MultiplayerRoomStats(
|
||||
room_id=daily_challenge_room,
|
||||
playlist_item_stats={},
|
||||
)
|
||||
return self._daily_challenge_stats
|
||||
|
||||
@staticmethod
|
||||
def online_presence_watchers_group() -> str:
|
||||
return ONLINE_PRESENCE_WATCHERS_GROUP
|
||||
|
||||
@staticmethod
|
||||
def room_watcher_group(room_id: int) -> str:
|
||||
return f"metadata:multiplayer-room-watchers:{room_id}"
|
||||
|
||||
def broadcast_tasks(
|
||||
self, user_id: int, store: MetadataClientState | None
|
||||
) -> set[Coroutine]:
|
||||
@@ -102,10 +140,29 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
"FriendPresenceUpdated",
|
||||
friend_id,
|
||||
friend_state if friend_state.pushable else None,
|
||||
friend_state.for_push
|
||||
if friend_state.pushable
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
daily_challenge_room = (
|
||||
await session.exec(
|
||||
select(Room).where(
|
||||
col(Room.ends_at) > datetime.now(UTC),
|
||||
Room.category == RoomCategory.DAILY_CHALLENGE,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if daily_challenge_room:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"DailyChallengeUpdated",
|
||||
DailyChallengeInfo(
|
||||
room_id=daily_challenge_room.id,
|
||||
),
|
||||
)
|
||||
redis = get_redis()
|
||||
await redis.set(f"metadata:online:{user_id}", "")
|
||||
|
||||
@@ -161,3 +218,76 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
|
||||
async def EndWatchingUserPresence(self, client: Client) -> None:
|
||||
self.remove_from_group(client, self.online_presence_watchers_group())
|
||||
|
||||
async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent):
|
||||
await self.broadcast_group_call(
|
||||
self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event
|
||||
)
|
||||
|
||||
async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int):
|
||||
self.add_to_group(client, self.room_watcher_group(room_id))
|
||||
await self.subscriber.subscribe_room_score(room_id, client.user_id)
|
||||
stats = self.get_daily_challenge_stats(room_id)
|
||||
await self.update_daily_challenge_stats(stats)
|
||||
return list(stats.playlist_item_stats.values())
|
||||
|
||||
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
|
||||
async with AsyncSession(engine) as session:
|
||||
playlist_ids = (
|
||||
await session.exec(
|
||||
select(Playlist.id).where(
|
||||
Playlist.room_id == stats.room_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
for playlist_id in playlist_ids:
|
||||
item = stats.playlist_item_stats.get(playlist_id, None)
|
||||
if item is None:
|
||||
item = MultiplayerPlaylistItemStats(
|
||||
playlist_item_id=playlist_id,
|
||||
total_score_distribution=[0] * TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
cumulative_score=0,
|
||||
last_processed_score_id=0,
|
||||
)
|
||||
stats.playlist_item_stats[playlist_id] = item
|
||||
last_processed_score_id = item.last_processed_score_id
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == stats.room_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.score_id > last_processed_score_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
if len(scores) == 0:
|
||||
continue
|
||||
|
||||
async with self._lock:
|
||||
if item.last_processed_score_id == last_processed_score_id:
|
||||
totals = defaultdict(int)
|
||||
for score in scores:
|
||||
bin_index = int(
|
||||
clamp(
|
||||
math.floor(score.total_score / 100000),
|
||||
0,
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS - 1,
|
||||
)
|
||||
)
|
||||
totals[bin_index] += 1
|
||||
|
||||
item.cumulative_score += sum(
|
||||
score.total_score for score in scores
|
||||
)
|
||||
|
||||
for j in range(TOTAL_SCORE_DISTRIBUTION_BINS):
|
||||
item.total_score_distribution[j] += totals.get(j, 0)
|
||||
|
||||
if scores:
|
||||
item.last_processed_score_id = max(
|
||||
score.score_id for score in scores
|
||||
)
|
||||
|
||||
async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int):
|
||||
self.remove_from_group(client, self.room_watcher_group(room_id))
|
||||
await self.subscriber.unsubscribe_room_score(room_id, client.user_id)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ from app.database import Beatmap, User
|
||||
from app.database.score import Score
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.dependencies.database import engine
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
|
||||
@@ -179,15 +180,13 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
return
|
||||
if state.beatmap_id is None or state.ruleset_id is None:
|
||||
return
|
||||
|
||||
fetcher = await get_fetcher()
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap).where(Beatmap.id == state.beatmap_id)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
return
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=state.beatmap_id
|
||||
)
|
||||
user = (
|
||||
await session.exec(select(User).where(User.id == user_id))
|
||||
).first()
|
||||
@@ -237,16 +236,16 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
score = store.score
|
||||
assert store.beatmap_status is not None
|
||||
assert store.state is not None
|
||||
assert store.score is not None
|
||||
if not score or not store.score_token:
|
||||
if (
|
||||
score is None
|
||||
or store.score_token is None
|
||||
or store.beatmap_status is None
|
||||
or store.state is None
|
||||
):
|
||||
return
|
||||
if (
|
||||
BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
|
||||
) and any(
|
||||
k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
|
||||
):
|
||||
) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()):
|
||||
await self._process_score(store, client)
|
||||
store.state = None
|
||||
store.beatmap_status = None
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
||||
from app.utils import camel_to_snake, snake_to_camel
|
||||
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
|
||||
|
||||
import msgpack_lazer_api as m
|
||||
from pydantic import BaseModel
|
||||
@@ -97,6 +97,8 @@ class MsgpackProtocol:
|
||||
return [cls.serialize_msgpack(item) for item in v]
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
return [v, 0]
|
||||
elif issubclass(typ, datetime.timedelta):
|
||||
return int(v.total_seconds() * 10_000_000)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
||||
@@ -126,15 +128,19 @@ class MsgpackProtocol:
|
||||
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
||||
if isinstance(v, list):
|
||||
d = {}
|
||||
for i, f in enumerate(typ.model_fields.items()):
|
||||
field, info = f
|
||||
if info.exclude:
|
||||
i = 0
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
anno = info.annotation
|
||||
if anno is None:
|
||||
d[camel_to_snake(field)] = v[i]
|
||||
continue
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
else:
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
i += 1
|
||||
return d
|
||||
return v
|
||||
|
||||
@@ -209,7 +215,9 @@ class MsgpackProtocol:
|
||||
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return v[0]
|
||||
elif isinstance(v, list):
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
return datetime.timedelta(seconds=int(v / 10_000_000))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
@@ -234,7 +242,9 @@ class MsgpackProtocol:
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
|
||||
if not all(
|
||||
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot validate {v} to {typ}, "
|
||||
"only SignalRUnionMessage subclasses are supported"
|
||||
@@ -292,36 +302,55 @@ class MsgpackProtocol:
|
||||
|
||||
class JSONProtocol:
|
||||
@classmethod
|
||||
def serialize_to_json(cls, v: Any):
|
||||
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_model(v)
|
||||
return cls.serialize_model(v, in_union)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_to_json(k): cls.serialize_to_json(value)
|
||||
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
|
||||
for k, value in v.items()
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
return [cls.serialize_to_json(item) for item in v]
|
||||
elif isinstance(v, datetime.datetime):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, Enum):
|
||||
elif isinstance(v, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
total_seconds = int(v.total_seconds())
|
||||
hours, remainder = divmod(total_seconds, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
||||
elif isinstance(v, Enum) and dict_key:
|
||||
return v.value
|
||||
elif isinstance(v, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
|
||||
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
|
||||
d = {}
|
||||
is_union = issubclass(v.__class__, SignalRUnionMessage)
|
||||
for field, info in v.__class__.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
|
||||
cls.serialize_to_json(getattr(v, field))
|
||||
name = (
|
||||
snake_to_camel(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
if not is_union
|
||||
else snake_to_pascal(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
)
|
||||
if issubclass(v.__class__, SignalRUnionMessage):
|
||||
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
|
||||
if is_union and not in_union:
|
||||
return {
|
||||
"$dtype": v.__class__.__name__,
|
||||
"$value": d,
|
||||
@@ -339,7 +368,12 @@ class JSONProtocol:
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
value = v.get(snake_to_camel(field, not from_union))
|
||||
name = (
|
||||
snake_to_camel(field, metadata.use_abbr if metadata else True)
|
||||
if not from_union
|
||||
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
|
||||
)
|
||||
value = v.get(name)
|
||||
anno = typ.model_fields[field].annotation
|
||||
if anno is None:
|
||||
d[field] = value
|
||||
@@ -397,7 +431,18 @@ class JSONProtocol:
|
||||
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
elif isinstance(v, list):
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
parts = v.split(":")
|
||||
if len(parts) == 3:
|
||||
return datetime.timedelta(
|
||||
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
|
||||
)
|
||||
elif len(parts) == 2:
|
||||
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
|
||||
elif len(parts) == 1:
|
||||
return datetime.timedelta(seconds=int(parts[0]))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from app.database import User
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
@@ -25,7 +25,7 @@ router = APIRouter()
|
||||
async def negotiate(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
negotiate_version: int = Query(1, alias="negotiateVersion"),
|
||||
user: User = Depends(get_current_user),
|
||||
user: DBUser = Depends(get_current_user),
|
||||
):
|
||||
connectionId = str(user.id)
|
||||
connectionToken = f"{connectionId}:{uuid.uuid4()}"
|
||||
|
||||
40
app/utils.py
40
app/utils.py
@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
||||
def snake_to_camel(name: str, use_abbr: bool = True) -> str:
|
||||
"""Convert a snake_case string to camelCase."""
|
||||
if not name:
|
||||
return name
|
||||
@@ -47,12 +47,46 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
||||
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.lower() in abbreviations:
|
||||
if part.lower() in abbreviations and use_abbr:
|
||||
result.append(part.upper())
|
||||
else:
|
||||
if result or not lower_case:
|
||||
if result:
|
||||
result.append(part.capitalize())
|
||||
else:
|
||||
result.append(part.lower())
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
|
||||
"""Convert a snake_case string to PascalCase."""
|
||||
if not name:
|
||||
return name
|
||||
|
||||
parts = name.split("_")
|
||||
if not parts:
|
||||
return name
|
||||
|
||||
# 常见缩写词列表
|
||||
abbreviations = {
|
||||
"id",
|
||||
"url",
|
||||
"api",
|
||||
"http",
|
||||
"https",
|
||||
"xml",
|
||||
"json",
|
||||
"css",
|
||||
"html",
|
||||
"sql",
|
||||
"db",
|
||||
}
|
||||
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.lower() in abbreviations and use_abbr:
|
||||
result.append(part.upper())
|
||||
else:
|
||||
result.append(part.capitalize())
|
||||
|
||||
return "".join(result)
|
||||
|
||||
Reference in New Issue
Block a user