From 0a80c5051cb7e5c1651322c174466474ebd8c7ea Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Tue, 5 Aug 2025 17:21:45 +0000 Subject: [PATCH] feat(multiplayer): support countdown --- app/models/metadata_hub.py | 15 ++++------ app/models/multiplayer_hub.py | 42 +++++++++++++-------------- app/models/signalr.py | 1 - app/signalr/hub/multiplayer.py | 52 ++++++++++++++++++++++++++++++---- app/signalr/packet.py | 39 ++++++++++++++----------- app/utils.py | 38 +++++++++++++++++++++++-- 6 files changed, 131 insertions(+), 56 deletions(-) diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index a678d7f..684ab54 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -1,11 +1,11 @@ from __future__ import annotations from enum import IntEnum -from typing import Annotated, ClassVar, Literal +from typing import ClassVar, Literal -from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState +from app.models.signalr import SignalRUnionMessage, UserState -from pydantic import BaseModel, Field +from pydantic import BaseModel class _UserActivity(SignalRUnionMessage): ... @@ -100,12 +100,9 @@ UserActivity = ( class UserPresence(BaseModel): - activity: Annotated[ - UserActivity | None, Field(default=None), SignalRMeta(use_upper_case=True) - ] - status: Annotated[ - OnlineStatus | None, Field(default=None), SignalRMeta(use_upper_case=True) - ] + activity: UserActivity | None = None + + status: OnlineStatus | None = None @property def pushable(self) -> bool: diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index e2f4edf..9d78282 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -53,10 +53,14 @@ class MultiplayerRoomSettings(BaseModel): auto_start_duration: timedelta = timedelta(seconds=0) auto_skip: bool = False + @property + def auto_start_enabled(self) -> bool: + return self.auto_start_duration != timedelta(seconds=0) + class BeatmapAvailability(BaseModel): state: DownloadState = DownloadState.UNKNOWN - progress: float | None = None + download_progress: float | None = None class _MatchUserState(SignalRUnionMessage): ... @@ -283,10 +287,12 @@ class PlaylistItem(BaseModel): return copy -class _MultiplayerCountdown(BaseModel): +class _MultiplayerCountdown(SignalRUnionMessage): id: int = 0 - remaining: timedelta - is_exclusive: bool = False + time_remaining: timedelta + is_exclusive: Annotated[ + bool, Field(default=True), SignalRMeta(member_ignore=True) + ] = True class MatchStartCountdown(_MultiplayerCountdown): @@ -310,7 +316,7 @@ class MultiplayerRoomUser(BaseModel): user_id: int state: MultiplayerUserState = MultiplayerUserState.IDLE availability: BeatmapAvailability = BeatmapAvailability( - state=DownloadState.UNKNOWN, progress=None + state=DownloadState.UNKNOWN, download_progress=None ) mods: list[APIMod] = Field(default_factory=list) match_state: MatchUserState | None = None @@ -602,8 +608,8 @@ class CountdownInfo: def __init__(self, countdown: MultiplayerCountdown): self.countdown = countdown self.duration = ( - countdown.remaining - if countdown.remaining > timedelta(seconds=0) + countdown.time_remaining + if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0) ) @@ -776,13 +782,12 @@ class ServerMultiplayerRoom: ): async def _countdown_task(self: "ServerMultiplayerRoom"): await asyncio.sleep(info.duration.total_seconds()) - await self.stop_countdown(countdown) if on_complete is not None: await on_complete(self) + await self.stop_countdown(countdown) if countdown.is_exclusive: await self.stop_all_countdowns() - countdown.id = await self.get_next_countdown_id() info = CountdownInfo(countdown) self.room.active_countdowns.append(info.countdown) @@ -793,21 +798,14 @@ class ServerMultiplayerRoom: info.task = asyncio.create_task(_countdown_task(self)) async def stop_countdown(self, countdown: MultiplayerCountdown): - info = next( - ( - info - for info in self._tracked_countdown.values() - if info.countdown.id == countdown.id - ), - None, - ) + info = self._tracked_countdown.get(countdown.id) if info is None: return - if info.task is not None and not info.task.done(): - info.task.cancel() del self._tracked_countdown[countdown.id] self.room.active_countdowns.remove(countdown) await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) + if info.task is not None and not info.task.done(): + info.task.cancel() async def stop_all_countdowns(self): for countdown in list(self._tracked_countdown.values()): @@ -817,19 +815,19 @@ class ServerMultiplayerRoom: self.room.active_countdowns.clear() -class _MatchServerEvent(BaseModel): ... +class _MatchServerEvent(SignalRUnionMessage): ... class CountdownStartedEvent(_MatchServerEvent): countdown: MultiplayerCountdown - type: Literal[0] = Field(default=0, exclude=True) + union_type: ClassVar[Literal[0]] = 0 class CountdownStoppedEvent(_MatchServerEvent): id: int - type: Literal[1] = Field(default=1, exclude=True) + union_type: ClassVar[Literal[1]] = 1 MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent diff --git a/app/models/signalr.py b/app/models/signalr.py index 7116ea0..ffbaf6b 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -13,7 +13,6 @@ from pydantic import ( class SignalRMeta: member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute - use_upper_case: bool = False # use upper CamelCase for field names use_abbr: bool = True diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 3be8024..ef3dfcd 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -19,12 +19,14 @@ from app.models.multiplayer_hub import ( GameplayAbortReason, MatchRequest, MatchServerEvent, + MatchStartCountdown, MultiplayerClientState, MultiplayerRoom, MultiplayerRoomSettings, MultiplayerRoomUser, PlaylistItem, ServerMultiplayerRoom, + ServerShuttingDownCountdown, StartMatchCountdownRequest, StopCountdownRequest, ) @@ -160,7 +162,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): availability = user.availability if ( availability.state == beatmap_availability.state - and availability.progress == beatmap_availability.progress + and availability.download_progress == beatmap_availability.download_progress ): return user.availability = beatmap_availability @@ -512,6 +514,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]): async def update_room_state(self, room: ServerMultiplayerRoom): match room.room.state: + case MultiplayerRoomState.OPEN: + if room.room.settings.auto_start_enabled: + if ( + not room.queue.current_item.expired + and any( + u.state == MultiplayerUserState.READY + for u in room.room.users + ) + and not any( + isinstance(countdown, MatchStartCountdown) + for countdown in room.room.active_countdowns + ) + ): + await room.start_countdown( + MatchStartCountdown( + time_remaining=room.room.settings.auto_start_duration + ), + self.start_match, + ) case MultiplayerRoomState.WAITING_FOR_LOAD: played_count = len( [True for user in room.room.users if user.state.is_playing] @@ -610,7 +631,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) await room.start_countdown( ForceGameplayStartCountdown( - remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) + time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) ), self.start_gameplay, ) @@ -885,15 +906,34 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("You are not in this room") if isinstance(request, StartMatchCountdownRequest): - # TODO: countdown - ... + if room.host and room.host.user_id != user.user_id: + raise InvokeException("You are not the host of this room") + if room.state != MultiplayerRoomState.OPEN: + raise InvokeException("Cannot start a countdown during ongoing play") + await server_room.start_countdown( + MatchStartCountdown(time_remaining=request.duration), + self.start_match, + ) elif isinstance(request, StopCountdownRequest): - ... + countdown = next( + (c for c in room.active_countdowns if c.id == request.id), + None, + ) + if countdown is None: + return + if ( + isinstance(countdown, MatchStartCountdown) + and room.settings.auto_start_enabled + ) or isinstance( + countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown) + ): + raise InvokeException("Cannot stop the requested countdown") + + await server_room.stop_countdown(countdown) else: await server_room.match_type_handler.handle_request(user, request) async def InvitePlayer(self, client: Client, user_id: int): - print(f"Inviting player... {client.user_id} {user_id}") store = self.get_or_create_state(client) if store.room_id == 0: raise InvokeException("You are not in a room") diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 9afb78d..09a36bd 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -15,7 +15,7 @@ from typing import ( ) from app.models.signalr import SignalRMeta, SignalRUnionMessage -from app.utils import camel_to_snake, snake_to_camel +from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal import msgpack_lazer_api as m from pydantic import BaseModel @@ -98,7 +98,7 @@ class MsgpackProtocol: elif issubclass(typ, datetime.datetime): return [v, 0] elif issubclass(typ, datetime.timedelta): - return int(v.total_seconds()) + return int(v.total_seconds() * 10_000_000) elif isinstance(v, dict): return { cls.serialize_msgpack(k): cls.serialize_msgpack(value) @@ -216,8 +216,8 @@ class MsgpackProtocol: elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): return v[0] elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): - return datetime.timedelta(seconds=int(v)) - elif isinstance(v, list): + return datetime.timedelta(seconds=int(v / 10_000_000)) + elif get_origin(typ) is list: return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): list_ = list(typ) @@ -300,10 +300,10 @@ class MsgpackProtocol: class JSONProtocol: @classmethod - def serialize_to_json(cls, v: Any, dict_key: bool = False): + def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False): typ = v.__class__ if issubclass(typ, BaseModel): - return cls.serialize_model(v) + return cls.serialize_model(v, in_union) elif isinstance(v, dict): return { cls.serialize_to_json(k, True): cls.serialize_to_json(value) @@ -327,22 +327,28 @@ class JSONProtocol: return v @classmethod - def serialize_model(cls, v: BaseModel) -> dict[str, Any]: + def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]: d = {} + is_union = issubclass(v.__class__, SignalRUnionMessage) for field, info in v.__class__.model_fields.items(): metadata = next( (m for m in info.metadata if isinstance(m, SignalRMeta)), None ) if metadata and metadata.json_ignore: continue - d[ + name = ( snake_to_camel( field, - metadata.use_upper_case if metadata else False, metadata.use_abbr if metadata else True, ) - ] = cls.serialize_to_json(getattr(v, field)) - if issubclass(v.__class__, SignalRUnionMessage): + if not is_union + else snake_to_pascal( + field, + metadata.use_abbr if metadata else True, + ) + ) + d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union) + if is_union and not in_union: return { "$dtype": v.__class__.__name__, "$value": d, @@ -360,11 +366,12 @@ class JSONProtocol: ) if metadata and metadata.json_ignore: continue - value = v.get( - snake_to_camel( - field, not from_union, metadata.use_abbr if metadata else True - ) + name = ( + snake_to_camel(field, metadata.use_abbr if metadata else True) + if not from_union + else snake_to_pascal(field, metadata.use_abbr if metadata else True) ) + value = v.get(name) anno = typ.model_fields[field].annotation if anno is None: d[field] = value @@ -433,7 +440,7 @@ class JSONProtocol: return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) elif len(parts) == 1: return datetime.timedelta(seconds=int(parts[0])) - elif isinstance(v, list): + elif get_origin(typ) is list: return [cls.validate_object(item, get_args(typ)[0]) for item in v] elif inspect.isclass(typ) and issubclass(typ, Enum): list_ = list(typ) diff --git a/app/utils.py b/app/utils.py index ac51b90..22f06dd 100644 --- a/app/utils.py +++ b/app/utils.py @@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str: return "".join(result) -def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str: +def snake_to_camel(name: str, use_abbr: bool = True) -> str: """Convert a snake_case string to camelCase.""" if not name: return name @@ -50,9 +50,43 @@ def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> if part.lower() in abbreviations and use_abbr: result.append(part.upper()) else: - if result or not lower_case: + if result: result.append(part.capitalize()) else: result.append(part.lower()) return "".join(result) + + +def snake_to_pascal(name: str, use_abbr: bool = True) -> str: + """Convert a snake_case string to PascalCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", + } + + result = [] + for part in parts: + if part.lower() in abbreviations and use_abbr: + result.append(part.upper()) + else: + result.append(part.capitalize()) + + return "".join(result)