From 2600fa499f05d266072fb447cd708af4587fc1c9 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 12:53:22 +0000 Subject: [PATCH] feat(multiplayer): support play WIP --- app/database/playlists.py | 18 +-- app/database/room.py | 2 +- app/database/score.py | 7 +- app/models/mods.py | 12 -- app/models/multiplayer_hub.py | 282 ++++++++++++++++++++++----------- app/models/room.py | 9 ++ app/models/signalr.py | 1 + app/router/score.py | 219 ++++++++++++++++++------- app/signalr/hub/multiplayer.py | 262 ++++++++++++++++++++++++++++-- app/signalr/packet.py | 46 +++++- app/utils.py | 4 +- 11 files changed, 666 insertions(+), 196 deletions(-) diff --git a/app/database/playlists.py b/app/database/playlists.py index 10ad86b..328f17d 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING from app.models.model import UTCBaseModel -from app.models.mods import APIMod, msgpack_to_apimod +from app.models.mods import APIMod from app.models.multiplayer_hub import PlaylistItem from .beatmap import Beatmap, BeatmapResp @@ -79,10 +79,10 @@ class Playlist(PlaylistBase, table=True): owner_id=playlist.owner_id, ruleset_id=playlist.ruleset_id, beatmap_id=playlist.beatmap_id, - required_mods=[msgpack_to_apimod(mod) for mod in playlist.required_mods], - allowed_mods=[msgpack_to_apimod(mod) for mod in playlist.allowed_mods], + required_mods=playlist.required_mods, + allowed_mods=playlist.allowed_mods, expired=playlist.expired, - playlist_order=playlist.order, + playlist_order=playlist.playlist_order, played_at=playlist.played_at, freestyle=playlist.freestyle, room_id=room_id, @@ -99,14 +99,10 @@ class Playlist(PlaylistBase, table=True): db_playlist.owner_id = playlist.owner_id db_playlist.ruleset_id = playlist.ruleset_id db_playlist.beatmap_id = playlist.beatmap_id - db_playlist.required_mods = [ - msgpack_to_apimod(mod) for mod in playlist.required_mods - ] - db_playlist.allowed_mods = [ - msgpack_to_apimod(mod) for mod in playlist.allowed_mods - ] + db_playlist.required_mods = playlist.required_mods + db_playlist.allowed_mods = playlist.allowed_mods db_playlist.expired = playlist.expired - db_playlist.playlist_order = playlist.order + db_playlist.playlist_order = playlist.playlist_order db_playlist.played_at = playlist.played_at db_playlist.freestyle = playlist.freestyle await session.commit() diff --git a/app/database/room.py b/app/database/room.py index 8eb882d..80457b6 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -125,7 +125,7 @@ class RoomResp(RoomBase): type=room.settings.match_type, queue_mode=room.settings.queue_mode, auto_skip=room.settings.auto_skip, - auto_start_duration=room.settings.auto_start_duration, + auto_start_duration=int(room.settings.auto_start_duration.total_seconds()), status=server_room.status, category=server_room.category, # duration = room.settings.duration, diff --git a/app/database/score.py b/app/database/score.py index 1bd5978..abc3d75 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -91,7 +91,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # optional # TODO: current_user_attributes - position: int | None = Field(default=None) # multiplayer + # position: int | None = Field(default=None) # multiplayer class Score(ScoreBase, table=True): @@ -162,6 +162,7 @@ class ScoreResp(ScoreBase): maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None rank_country: int | None = None + position: int = 1 # TODO @classmethod async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": @@ -618,6 +619,8 @@ async def process_score( fetcher: "Fetcher", session: AsyncSession, redis: Redis, + item_id: int | None = None, + room_id: int | None = None, ) -> Score: assert user.id can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) @@ -649,6 +652,8 @@ async def process_score( nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0), nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), + playlist_item_id=item_id, + room_id=room_id, ) if can_get_pp: beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) diff --git a/app/models/mods.py b/app/models/mods.py index 4b20138..299a05f 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -5,8 +5,6 @@ from typing import Literal, NotRequired, TypedDict from app.path import STATIC_DIR -from msgpack_lazer_api import APIMod as MsgpackAPIMod - class APIMod(TypedDict): acronym: str @@ -169,13 +167,3 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool: if expected_value != NO_CHECK and value != expected_value: return False return True - - -def msgpack_to_apimod(mod: MsgpackAPIMod) -> APIMod: - """ - Convert a MsgpackAPIMod to an APIMod. - """ - return APIMod( - acronym=mod.acronym, - settings=mod.settings, - ) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 9bccb71..ba8a050 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -1,13 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass -import datetime -from typing import TYPE_CHECKING, Annotated, Any, Literal +import asyncio +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal from app.database.beatmap import Beatmap from app.dependencies.database import engine from app.exception import InvokeException +from .mods import APIMod from .room import ( DownloadState, MatchType, @@ -18,15 +21,14 @@ from .room import ( RoomStatus, ) from .signalr import ( - EnumByIndex, - MessagePackArrayModel, + SignalRMeta, + SignalRUnionMessage, UserState, - msgpack_union, - msgpack_union_dump, ) -from msgpack_lazer_api import APIMod -from pydantic import Field, field_serializer, field_validator +from pydantic import BaseModel, Field +from sqlalchemy import update +from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: @@ -40,37 +42,37 @@ class MultiplayerClientState(UserState): room_id: int = 0 -class MultiplayerRoomSettings(MessagePackArrayModel): +class MultiplayerRoomSettings(BaseModel): name: str = "Unnamed Room" - playlist_item_id: int = 0 + playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] password: str = "" - match_type: Annotated[MatchType, EnumByIndex(MatchType)] = MatchType.HEAD_TO_HEAD - queue_mode: Annotated[QueueMode, EnumByIndex(QueueMode)] = QueueMode.HOST_ONLY - auto_start_duration: int = 0 + match_type: MatchType = MatchType.HEAD_TO_HEAD + queue_mode: QueueMode = QueueMode.HOST_ONLY + auto_start_duration: timedelta = timedelta(seconds=0) auto_skip: bool = False -class BeatmapAvailability(MessagePackArrayModel): - state: Annotated[DownloadState, EnumByIndex(DownloadState)] = DownloadState.UNKNOWN +class BeatmapAvailability(BaseModel): + state: DownloadState = DownloadState.UNKNOWN progress: float | None = None -class _MatchUserState(MessagePackArrayModel): ... +class _MatchUserState(SignalRUnionMessage): ... class TeamVersusUserState(_MatchUserState): team_id: int - type: Literal[0] = Field(0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 MatchUserState = TeamVersusUserState -class _MatchRoomState(MessagePackArrayModel): ... +class _MatchRoomState(SignalRUnionMessage): ... -class MultiplayerTeam(MessagePackArrayModel): +class MultiplayerTeam(BaseModel): id: int name: str @@ -83,24 +85,24 @@ class TeamVersusRoomState(_MatchRoomState): ] ) - type: Literal[0] = Field(0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 MatchRoomState = TeamVersusRoomState -class PlaylistItem(MessagePackArrayModel): - id: int +class PlaylistItem(BaseModel): + id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] owner_id: int beatmap_id: int - checksum: str + beatmap_checksum: str ruleset_id: int required_mods: list[APIMod] = Field(default_factory=list) allowed_mods: list[APIMod] = Field(default_factory=list) expired: bool - order: int - played_at: datetime.datetime | None = None - star: float + playlist_order: int + played_at: datetime | None = None + star_rating: float freestyle: bool def validate_user_mods( @@ -127,7 +129,10 @@ class PlaylistItem(MessagePackArrayModel): # Check if mods are valid for the ruleset for mod in proposed_mods: - if ruleset_key not in API_MODS or mod.acronym not in API_MODS[ruleset_key]: + if ( + ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[ruleset_key] + ): all_proposed_valid = False continue valid_mods.append(mod) @@ -136,35 +141,35 @@ class PlaylistItem(MessagePackArrayModel): incompatible_mods = set() final_valid_mods = [] for mod in valid_mods: - if mod.acronym in incompatible_mods: + if mod["acronym"] in incompatible_mods: all_proposed_valid = False continue - setting_mods = API_MODS[ruleset_key].get(mod.acronym) + setting_mods = API_MODS[ruleset_key].get(mod["acronym"]) if setting_mods: incompatible_mods.update(setting_mods["IncompatibleMods"]) final_valid_mods.append(mod) # If not freestyle, check against allowed mods if not self.freestyle: - allowed_acronyms = {mod.acronym for mod in self.allowed_mods} + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} filtered_valid_mods = [] for mod in final_valid_mods: - if mod.acronym not in allowed_acronyms: + if mod["acronym"] not in allowed_acronyms: all_proposed_valid = False else: filtered_valid_mods.append(mod) final_valid_mods = filtered_valid_mods # Check compatibility with required mods - required_mod_acronyms = {mod.acronym for mod in self.required_mods} + required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} all_mod_acronyms = { - mod.acronym for mod in final_valid_mods + mod["acronym"] for mod in final_valid_mods } | required_mod_acronyms # Check for incompatibility between required and user mods filtered_valid_mods = [] for mod in final_valid_mods: - mod_acronym = mod.acronym + mod_acronym = mod["acronym"] is_compatible = True for other_acronym in all_mod_acronyms: @@ -181,23 +186,29 @@ class PlaylistItem(MessagePackArrayModel): return all_proposed_valid, filtered_valid_mods + def clone(self) -> "PlaylistItem": + copy = self.model_copy() + copy.required_mods = list(self.required_mods) + copy.allowed_mods = list(self.allowed_mods) + return copy -class _MultiplayerCountdown(MessagePackArrayModel): - id: int - remaining: int - is_exclusive: bool + +class _MultiplayerCountdown(BaseModel): + id: int = 0 + remaining: timedelta + is_exclusive: bool = False class MatchStartCountdown(_MultiplayerCountdown): - type: Literal[0] = Field(0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 class ForceGameplayStartCountdown(_MultiplayerCountdown): - type: Literal[1] = Field(1, exclude=True) + union_type: ClassVar[Literal[1]] = 1 class ServerShuttingDownCountdown(_MultiplayerCountdown): - type: Literal[2] = Field(2, exclude=True) + union_type: ClassVar[Literal[2]] = 2 MultiplayerCountdown = ( @@ -205,11 +216,9 @@ MultiplayerCountdown = ( ) -class MultiplayerRoomUser(MessagePackArrayModel): +class MultiplayerRoomUser(BaseModel): user_id: int - state: Annotated[MultiplayerUserState, EnumByIndex(MultiplayerUserState)] = ( - MultiplayerUserState.IDLE - ) + state: MultiplayerUserState = MultiplayerUserState.IDLE availability: BeatmapAvailability = BeatmapAvailability( state=DownloadState.UNKNOWN, progress=None ) @@ -218,50 +227,33 @@ class MultiplayerRoomUser(MessagePackArrayModel): ruleset_id: int | None = None # freestyle beatmap_id: int | None = None # freestyle - @field_validator("match_state", mode="before") - def union_validate(v: Any): - if isinstance(v, list): - return msgpack_union(v) - return v - @field_serializer("match_state") - def union_serialize(v: Any): - return msgpack_union_dump(v) - - -class MultiplayerRoom(MessagePackArrayModel): +class MultiplayerRoom(BaseModel): room_id: int - state: Annotated[MultiplayerRoomState, EnumByIndex(MultiplayerRoomState)] + state: MultiplayerRoomState settings: MultiplayerRoomSettings users: list[MultiplayerRoomUser] = Field(default_factory=list) host: MultiplayerRoomUser | None = None match_state: MatchRoomState | None = None playlist: list[PlaylistItem] = Field(default_factory=list) - active_cooldowns: list[MultiplayerCountdown] = Field(default_factory=list) + active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list) channel_id: int - @field_validator("match_state", mode="before") - def union_validate(v: Any): - if isinstance(v, list): - return msgpack_union(v) - return v - - @field_serializer("match_state") - def union_serialize(v: Any): - return msgpack_union_dump(v) - class MultiplayerQueue: - def __init__(self, room: "ServerMultiplayerRoom", hub: "MultiplayerHub"): + def __init__(self, room: "ServerMultiplayerRoom"): self.server_room = room - self.hub = hub self.current_index = 0 + @property + def hub(self) -> "MultiplayerHub": + return self.server_room.hub + @property def upcoming_items(self): return sorted( (item for item in self.room.playlist if not item.expired), - key=lambda i: i.order, + key=lambda i: i.playlist_order, ) @property @@ -323,9 +315,9 @@ class MultiplayerQueue: ) async with AsyncSession(engine) as session: for idx, item in enumerate(ordered_active_items): - if item.order == idx: + if item.playlist_order == idx: continue - item.order = idx + item.playlist_order = idx await Playlist.update(item, self.room.room_id, session) await self.hub.playlist_changed( self.server_room, item, beatmap_changed=False @@ -338,7 +330,7 @@ class MultiplayerQueue: if upcoming_items else max( self.room.playlist, - key=lambda i: i.played_at or datetime.datetime.min, + key=lambda i: i.played_at or datetime.min, ) ) self.current_index = self.room.playlist.index(next_item) @@ -356,14 +348,7 @@ class MultiplayerQueue: limit = HOST_LIMIT if is_host else PER_USER_LIMIT if ( - len( - list( - filter( - lambda x: x.owner_id == user.user_id, - self.room.playlist, - ) - ) - ) + len([True for u in self.room.playlist if u.owner_id == user.user_id]) >= limit ): raise InvokeException(f"You can only have {limit} items in the queue") @@ -376,11 +361,11 @@ class MultiplayerQueue: beatmap = await session.get(Beatmap, item.beatmap_id) if beatmap is None: raise InvokeException("Beatmap not found") - if item.checksum != beatmap.checksum: + if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") # TODO: mods validation item.owner_id = user.user_id - item.star = float( + item.star_rating = float( beatmap.difficulty_rating ) # FIXME: beatmap use decimal await Playlist.add_to_db(item, self.room.room_id, session) @@ -400,7 +385,7 @@ class MultiplayerQueue: beatmap = await session.get(Beatmap, item.beatmap_id) if beatmap is None: raise InvokeException("Beatmap not found") - if item.checksum != beatmap.checksum: + if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") existing_item = next( @@ -423,8 +408,8 @@ class MultiplayerQueue: # TODO: mods validation item.owner_id = user.user_id - item.star = float(beatmap.difficulty_rating) - item.order = existing_item.order + item.star_rating = float(beatmap.difficulty_rating) + item.playlist_order = existing_item.playlist_order await Playlist.update(item, self.room.room_id, session) @@ -437,7 +422,8 @@ class MultiplayerQueue: await self.hub.playlist_changed( self.server_room, item, - beatmap_changed=item.checksum != existing_item.checksum, + beatmap_changed=item.beatmap_checksum + != existing_item.beatmap_checksum, ) async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): @@ -477,12 +463,46 @@ class MultiplayerQueue: await self.update_current_item() await self.hub.playlist_removed(self.server_room, item.id) + async def finish_current_item(self): + from app.database import Playlist + + async with AsyncSession(engine) as session: + played_at = datetime.now(UTC) + await session.execute( + update(Playlist) + .where( + col(Playlist.id) == self.current_item.id, + col(Playlist.room_id) == self.room.room_id, + ) + .values(expired=True, played_at=played_at) + ) + self.room.playlist[self.current_index].expired = True + self.room.playlist[self.current_index].played_at = played_at + await self.hub.playlist_changed(self.server_room, self.current_item, True) + await self.update_order() + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( + playitem.expired for playitem in self.room.playlist + ): + assert self.room.host + await self.add_item(self.current_item.clone(), self.room.host) + @property def current_item(self): - """Get the current playlist item""" - current_id = self.room.settings.playlist_item_id - return next( - (item for item in self.room.playlist if item.id == current_id), + return self.room.playlist[self.current_index] + + +@dataclass +class CountdownInfo: + countdown: MultiplayerCountdown + duration: timedelta + task: asyncio.Task | None = None + + def __init__(self, countdown: MultiplayerCountdown): + self.countdown = countdown + self.duration = ( + countdown.remaining + if countdown.remaining > timedelta(seconds=0) + else timedelta(seconds=0) ) @@ -491,5 +511,79 @@ class ServerMultiplayerRoom: room: MultiplayerRoom category: RoomCategory status: RoomStatus - start_at: datetime.datetime + start_at: datetime + hub: "MultiplayerHub" queue: MultiplayerQueue | None = None + _next_countdown_id: int = 0 + _countdown_id_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + _tracked_countdown: dict[int, CountdownInfo] = field(default_factory=dict) + + async def get_next_countdown_id(self) -> int: + async with self._countdown_id_lock: + self._next_countdown_id += 1 + return self._next_countdown_id + + async def start_countdown( + self, + countdown: MultiplayerCountdown, + on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None, + ): + async def _countdown_task(self: "ServerMultiplayerRoom"): + await asyncio.sleep(info.duration.total_seconds()) + await self.stop_countdown(countdown) + if on_complete is not None: + await on_complete(self) + + if countdown.is_exclusive: + await self.stop_all_countdowns() + + countdown.id = await self.get_next_countdown_id() + info = CountdownInfo(countdown) + self.room.active_countdowns.append(info.countdown) + self._tracked_countdown[countdown.id] = info + await self.hub.send_match_event( + self, CountdownStartedEvent(countdown=info.countdown) + ) + info.task = asyncio.create_task(_countdown_task(self)) + + async def stop_countdown(self, countdown: MultiplayerCountdown): + info = next( + ( + info + for info in self._tracked_countdown.values() + if info.countdown.id == countdown.id + ), + None, + ) + if info is None: + return + if info.task is not None and not info.task.done(): + info.task.cancel() + del self._tracked_countdown[countdown.id] + self.room.active_countdowns.remove(countdown) + await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) + + async def stop_all_countdowns(self): + for countdown in list(self._tracked_countdown.values()): + await self.stop_countdown(countdown.countdown) + + self._tracked_countdown.clear() + self.room.active_countdowns.clear() + + +class _MatchServerEvent(BaseModel): ... + + +class CountdownStartedEvent(_MatchServerEvent): + countdown: MultiplayerCountdown + + type: Literal[0] = Field(default=0, exclude=True) + + +class CountdownStoppedEvent(_MatchServerEvent): + id: int + + type: Literal[1] = Field(default=1, exclude=True) + + +MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent diff --git a/app/models/room.py b/app/models/room.py index 42f897c..392562a 100644 --- a/app/models/room.py +++ b/app/models/room.py @@ -53,6 +53,15 @@ class MultiplayerUserState(str, Enum): RESULTS = "results" SPECTATING = "spectating" + @property + def is_playing(self) -> bool: + return self in { + self.WAITING_FOR_LOAD, + self.PLAYING, + self.READY_FOR_GAMEPLAY, + self.LOADED, + } + class DownloadState(str, Enum): UNKNOWN = "unknown" diff --git a/app/models/signalr.py b/app/models/signalr.py index de66e30..7116ea0 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -14,6 +14,7 @@ class SignalRMeta: member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute use_upper_case: bool = False # use upper CamelCase for field names + use_abbr: bool = True class SignalRUnionMessage(BaseModel): diff --git a/app/router/score.py b/app/router/score.py index 2f1303e..b50911d 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,10 +1,19 @@ from __future__ import annotations -from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User +from app.database import ( + Beatmap, + Playlist, + Score, + ScoreResp, + ScoreToken, + ScoreTokenResp, + User, +) from app.database.score import get_leaderboard, process_score, process_user from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user +from app.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus from app.models.score import ( INT_TO_MODE, @@ -13,6 +22,7 @@ from app.models.score import ( Rank, SoloScoreSubmissionInfo, ) +from app.signalr.hub import MultiplayerHubs from .api_router import router @@ -24,6 +34,68 @@ from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession +async def submit_score( + info: SoloScoreSubmissionInfo, + beatmap: int, + token: int, + current_user: User, + db: AsyncSession, + redis: Redis, + fetcher: Fetcher, + item_id: int | None = None, + room_id: int | None = None, +): + if not info.passed: + info.rank = Rank.F + score_token = ( + await db.exec( + select(ScoreToken) + .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] + .where(ScoreToken.id == token) + ) + ).first() + if not score_token or score_token.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Score token not found") + if score_token.score_id: + score = ( + await db.exec( + select(Score).where( + Score.id == score_token.score_id, + Score.user_id == current_user.id, + ) + ) + ).first() + if not score: + raise HTTPException(status_code=404, detail="Score not found") + else: + beatmap_status = ( + await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)) + ).first() + if beatmap_status is None: + raise HTTPException(status_code=404, detail="Beatmap not found") + ranked = beatmap_status in { + BeatmapRankStatus.RANKED, + BeatmapRankStatus.APPROVED, + } + score = await process_score( + current_user, + beatmap, + ranked, + score_token, + info, + fetcher, + db, + redis, + ) + await db.refresh(current_user) + score_id = score.id + score_token.score_id = score_id + await process_user(db, current_user, score, ranked) + score = (await db.exec(select(Score).where(Score.id == score_id))).first() + assert score is not None + return await ScoreResp.from_db(db, score) + + class BeatmapScores(BaseModel): scores: list[ScoreResp] userScore: ScoreResp | None = None @@ -97,9 +169,10 @@ async def get_user_beatmap_score( status_code=404, detail=f"Cannot find user {user}'s score on this beatmap" ) else: + resp = await ScoreResp.from_db(db, user_score) return BeatmapUserScore( - position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score), + position=resp.rank_global or 0, + score=resp, ) @@ -173,55 +246,95 @@ async def submit_solo_score( redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), ): - if not info.passed: - info.rank = Rank.F - async with db: - score_token = ( - await db.exec( - select(ScoreToken) - .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] - .where(ScoreToken.id == token, ScoreToken.user_id == current_user.id) + return await submit_score(info, beatmap, token, current_user, db, redis, fetcher) + + +@router.post( + "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp +) +async def create_playlist_score( + room_id: int, + playlist_id: int, + beatmap_id: int = Form(), + beatmap_hash: str = Form(), + ruleset_id: int = Form(..., ge=0, le=3), + version_hash: str = Form(""), + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), +): + room = MultiplayerHubs.rooms[room_id] + if not room: + raise HTTPException(status_code=404, detail="Room not found") + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id ) - ).first() - if not score_token or score_token.user_id != current_user.id: - raise HTTPException(status_code=404, detail="Score token not found") - if score_token.score_id: - score = ( - await db.exec( - select(Score).where( - Score.id == score_token.score_id, - Score.user_id == current_user.id, - ) - ) - ).first() - if not score: - raise HTTPException(status_code=404, detail="Score not found") - else: - beatmap_status = ( - await db.exec( - select(Beatmap.beatmap_status).where(Beatmap.id == beatmap) - ) - ).first() - if beatmap_status is None: - raise HTTPException(status_code=404, detail="Beatmap not found") - ranked = beatmap_status in { - BeatmapRankStatus.RANKED, - BeatmapRankStatus.APPROVED, - } - score = await process_score( - current_user, - beatmap, - ranked, - score_token, - info, - fetcher, - db, - redis, + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist not found") + + # validate + if not item.freestyle: + if item.ruleset_id != ruleset_id: + raise HTTPException( + status_code=400, detail="Ruleset mismatch in playlist item" ) - await db.refresh(current_user) - score_id = score.id - score_token.score_id = score_id - await process_user(db, current_user, score, ranked) - score = (await db.exec(select(Score).where(Score.id == score_id))).first() - assert score is not None - return await ScoreResp.from_db(db, score) + if item.beatmap_id != beatmap_id: + raise HTTPException( + status_code=400, detail="Beatmap ID mismatch in playlist item" + ) + # TODO: max attempts + if item.expired: + raise HTTPException(status_code=400, detail="Playlist item has expired") + if item.played_at: + raise HTTPException( + status_code=400, detail="Playlist item has already been played" + ) + # 这里应该不用验证mod了吧。。。 + + score_token = ScoreToken( + user_id=current_user.id, + beatmap_id=beatmap_id, + ruleset_id=INT_TO_MODE[ruleset_id], + playlist_item_id=playlist_id, + ) + session.add(score_token) + await session.commit() + await session.refresh(score_token) + return ScoreTokenResp.from_db(score_token) + + +@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}") +async def submit_playlist_score( + room_id: int, + playlist_id: int, + token: int, + info: SoloScoreSubmissionInfo, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), + fetcher: Fetcher = Depends(get_fetcher), +): + item = ( + await session.exec( + select(Playlist).where( + Playlist.id == playlist_id, Playlist.room_id == room_id + ) + ) + ).first() + if not item: + raise HTTPException(status_code=404, detail="Playlist item not found") + score_resp = await submit_score( + info, + item.beatmap_id, + token, + current_user, + session, + redis, + fetcher, + item.id, + room_id, + ) + return score_resp diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index bd34be0..21f192c 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +from datetime import timedelta from typing import override from app.database import Room @@ -8,8 +10,11 @@ from app.database.playlists import Playlist from app.dependencies.database import engine from app.exception import InvokeException from app.log import logger +from app.models.mods import APIMod from app.models.multiplayer_hub import ( BeatmapAvailability, + ForceGameplayStartCountdown, + MatchServerEvent, MultiplayerClientState, MultiplayerQueue, MultiplayerRoom, @@ -17,16 +22,22 @@ from app.models.multiplayer_hub import ( PlaylistItem, ServerMultiplayerRoom, ) -from app.models.room import RoomCategory, RoomStatus +from app.models.room import ( + DownloadState, + MultiplayerRoomState, + MultiplayerUserState, + RoomCategory, + RoomStatus, +) from app.models.score import GameMode -from app.models.signalr import serialize_to_list from .hub import Client, Hub -from msgpack_lazer_api import APIMod from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +GAMEPLAY_LOAD_TIMEOUT = 30 + class MultiplayerHub(Hub[MultiplayerClientState]): @override @@ -58,7 +69,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): type=room.settings.match_type, queue_mode=room.settings.queue_mode, auto_skip=room.settings.auto_skip, - auto_start_duration=room.settings.auto_start_duration, + auto_start_duration=int( + room.settings.auto_start_duration.total_seconds() + ), host_id=client.user_id, status=RoomStatus.IDLE, ) @@ -75,10 +88,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): category=RoomCategory.NORMAL, status=RoomStatus.IDLE, start_at=starts_at, + hub=self, ) queue = MultiplayerQueue( room=server_room, - hub=self, ) server_room.queue = queue self.rooms[room.room_id] = server_room @@ -86,6 +99,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): client, room.room_id, room.settings.password ) + async def JoinRoom(self, client: Client, room_id: int): + return self.JoinRoomWithPassword(client, room_id, "") + async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str): logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}") store = self.get_or_create_state(client) @@ -105,12 +121,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # from CreateRoom room.host = user store.room_id = room_id - await self.broadcast_group_call( - self.group_id(room_id), "UserJoined", serialize_to_list(user) - ) + await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user) room.users.append(user) self.add_to_group(client, self.group_id(room_id)) - return serialize_to_list(room) + return room async def ChangeBeatmapAvailability( self, client: Client, beatmap_availability: BeatmapAvailability @@ -132,12 +146,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]): and availability.progress == beatmap_availability.progress ): return - user.availability = availability + user.availability = beatmap_availability await self.broadcast_group_call( self.group_id(store.room_id), "UserBeatmapAvailabilityChanged", user.user_id, - serialize_to_list(beatmap_availability), + (beatmap_availability), ) async def AddPlaylistItem(self, client: Client, item: PlaylistItem): @@ -198,14 +212,14 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.broadcast_group_call( self.group_id(room.room.room_id), "SettingsChanged", - serialize_to_list(room.room.settings), + (room.room.settings), ) async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): await self.broadcast_group_call( self.group_id(room.room.room_id), "PlaylistItemAdded", - serialize_to_list(item), + (item), ) async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): @@ -221,7 +235,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.broadcast_group_call( self.group_id(room.room.room_id), "PlaylistItemChanged", - serialize_to_list(item), + (item), ) async def ChangeUserStyle( @@ -378,7 +392,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) if not is_valid: incompatible_mods = [ - mod.acronym for mod in new_mods if mod not in valid_mods + mod["acronym"] for mod in new_mods if mod not in valid_mods ] raise InvokeException( f"Incompatible mods were selected: {','.join(incompatible_mods)}" @@ -395,3 +409,221 @@ class MultiplayerHub(Hub[MultiplayerClientState]): user.user_id, valid_mods, ) + + async def validate_user_stare( + self, + room: ServerMultiplayerRoom, + old: MultiplayerUserState, + new: MultiplayerUserState, + ): + assert room.queue + match new: + case MultiplayerUserState.IDLE: + if old.is_playing: + raise InvokeException( + "Cannot return to idle without aborting gameplay." + ) + case MultiplayerUserState.READY: + if old != MultiplayerUserState.IDLE: + raise InvokeException(f"Cannot change state from {old} to {new}") + if room.queue.current_item.expired: + raise InvokeException( + "Cannot ready up while all items have been played." + ) + case MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.LOADED: + if old != MultiplayerUserState.WAITING_FOR_LOAD: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.READY_FOR_GAMEPLAY: + if old != MultiplayerUserState.LOADED: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.PLAYING: + raise InvokeException("State is managed by the server.") + case MultiplayerUserState.FINISHED_PLAY: + if old != MultiplayerUserState.PLAYING: + raise InvokeException(f"Cannot change state from {old} to {new}") + case MultiplayerUserState.RESULTS: + raise InvokeException("Cannot change state from {old} to {new}") + case MultiplayerUserState.SPECTATING: + if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY): + raise InvokeException(f"Cannot change state from {old} to {new}") + + async def ChangeState(self, client: Client, state: MultiplayerUserState): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + if user.state == state: + return + match state: + case MultiplayerUserState.IDLE: + if user.state.is_playing: + return + case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY: + if not user.state.is_playing: + return + await self.validate_user_stare( + server_room, + user.state, + state, + ) + await self.change_user_state(server_room, user, state) + if state == MultiplayerUserState.SPECTATING and ( + room.state == MultiplayerRoomState.PLAYING + or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + ): + await self.call_noblock(client, "LoadRequested") + await self.update_room_state(server_room) + + async def change_user_state( + self, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + state: MultiplayerUserState, + ): + user.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStateChanged", + user.user_id, + user.state, + ) + + async def update_room_state(self, room: ServerMultiplayerRoom): + match room.room.state: + case MultiplayerRoomState.WAITING_FOR_LOAD: + played_count = len( + [True for user in room.room.users if user.state.is_playing] + ) + ready_count = len( + [ + True + for user in room.room.users + if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY + ] + ) + if played_count == ready_count: + await self.start_gameplay(room) + case MultiplayerRoomState.PLAYING: + assert room.queue + if all( + u.state != MultiplayerUserState.PLAYING for u in room.room.users + ): + for u in filter( + lambda u: u.state == MultiplayerUserState.FINISHED_PLAY, + room.room.users, + ): + await self.change_user_state( + room, u, MultiplayerUserState.RESULTS + ) + await self.change_room_state(room, MultiplayerRoomState.OPEN) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "ResultsReady", + ) + await room.queue.finish_current_item() + + async def change_room_state( + self, room: ServerMultiplayerRoom, state: MultiplayerRoomState + ): + room.room.state = state + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "RoomStateChanged", + state, + ) + + async def StartMatch(self, client: Client): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + if room.host is None or room.host.user_id != client.user_id: + raise InvokeException("You are not the host of this room") + if any(u.state != MultiplayerUserState.READY for u in room.users): + raise InvokeException("Not all users are ready") + + await self.start_match(server_room) + + async def start_match(self, room: ServerMultiplayerRoom): + assert room.queue + if room.room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Can't start match when already in a running state.") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + ready_users = [ + u + for u in room.room.users + if u.availability.state == DownloadState.LOCALLY_AVAILABLE + and ( + u.state == MultiplayerUserState.READY + or u.state == MultiplayerUserState.IDLE + ) + ] + await asyncio.gather( + *[ + self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD) + for u in ready_users + ] + ) + await self.change_room_state( + room, + MultiplayerRoomState.WAITING_FOR_LOAD, + ) + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "LoadRequested", + ) + await room.start_countdown( + ForceGameplayStartCountdown( + remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) + ), + self.start_gameplay, + ) + + async def start_gameplay(self, room: ServerMultiplayerRoom): + assert room.queue + if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD: + raise InvokeException("Room is not ready for gameplay") + if room.queue.current_item.expired: + raise InvokeException("Current playlist item is expired") + playing = False + for user in room.room.users: + client = self.get_client_by_id(str(user.user_id)) + if client is None: + continue + + if user.state in ( + MultiplayerUserState.READY_FOR_GAMEPLAY, + MultiplayerUserState.LOADED, + ): + playing = True + await self.change_user_state(room, user, MultiplayerUserState.PLAYING) + await self.call_noblock(client, "GameplayStarted") + await self.change_room_state( + room, + (MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN), + ) + + async def send_match_event( + self, room: ServerMultiplayerRoom, event: MatchServerEvent + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "MatchEvent", + event, + ) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 70c2276..9afb78d 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -97,6 +97,8 @@ class MsgpackProtocol: return [cls.serialize_msgpack(item) for item in v] elif issubclass(typ, datetime.datetime): return [v, 0] + elif issubclass(typ, datetime.timedelta): + return int(v.total_seconds()) elif isinstance(v, dict): return { cls.serialize_msgpack(k): cls.serialize_msgpack(value) @@ -213,6 +215,8 @@ class MsgpackProtocol: return typ.model_validate(obj=cls.process_object(v, typ)) elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return v[0] + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + return datetime.timedelta(seconds=int(v)) elif isinstance(v, list): return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): @@ -296,21 +300,30 @@ class MsgpackProtocol: class JSONProtocol: @classmethod - def serialize_to_json(cls, v: Any): + def serialize_to_json(cls, v: Any, dict_key: bool = False): typ = v.__class__ if issubclass(typ, BaseModel): return cls.serialize_model(v) elif isinstance(v, dict): return { - cls.serialize_to_json(k): cls.serialize_to_json(value) + cls.serialize_to_json(k, True): cls.serialize_to_json(value) for k, value in v.items() } elif isinstance(v, list): return [cls.serialize_to_json(item) for item in v] elif isinstance(v, datetime.datetime): return v.isoformat() - elif isinstance(v, Enum): + elif isinstance(v, datetime.timedelta): + # d.hh:mm:ss + total_seconds = int(v.total_seconds()) + hours, remainder = divmod(total_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + return f"{hours:02}:{minutes:02}:{seconds:02}" + elif isinstance(v, Enum) and dict_key: return v.value + elif isinstance(v, Enum): + list_ = list(typ) + return list_.index(v) return v @classmethod @@ -322,9 +335,13 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = ( - cls.serialize_to_json(getattr(v, field)) - ) + d[ + snake_to_camel( + field, + metadata.use_upper_case if metadata else False, + metadata.use_abbr if metadata else True, + ) + ] = cls.serialize_to_json(getattr(v, field)) if issubclass(v.__class__, SignalRUnionMessage): return { "$dtype": v.__class__.__name__, @@ -343,7 +360,11 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - value = v.get(snake_to_camel(field, not from_union)) + value = v.get( + snake_to_camel( + field, not from_union, metadata.use_abbr if metadata else True + ) + ) anno = typ.model_fields[field].annotation if anno is None: d[field] = value @@ -401,6 +422,17 @@ class JSONProtocol: return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return datetime.datetime.fromisoformat(v) + elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): + # d.hh:mm:ss + parts = v.split(":") + if len(parts) == 3: + return datetime.timedelta( + hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]) + ) + elif len(parts) == 2: + return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) + elif len(parts) == 1: + return datetime.timedelta(seconds=int(parts[0])) elif isinstance(v, list): return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): diff --git a/app/utils.py b/app/utils.py index 0d759a1..ac51b90 100644 --- a/app/utils.py +++ b/app/utils.py @@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str: return "".join(result) -def snake_to_camel(name: str, lower_case: bool = True) -> str: +def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str: """Convert a snake_case string to camelCase.""" if not name: return name @@ -47,7 +47,7 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str: result = [] for part in parts: - if part.lower() in abbreviations: + if part.lower() in abbreviations and use_abbr: result.append(part.upper()) else: if result or not lower_case: