feat(multiplayer): support countdown

This commit is contained in:
MingxuanGame
2025-08-05 17:21:45 +00:00
parent 0988f1fc0c
commit 0a80c5051c
6 changed files with 131 additions and 56 deletions

View File

@@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from enum import IntEnum from enum import IntEnum
from typing import Annotated, ClassVar, Literal from typing import ClassVar, Literal
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState from app.models.signalr import SignalRUnionMessage, UserState
from pydantic import BaseModel, Field from pydantic import BaseModel
class _UserActivity(SignalRUnionMessage): ... class _UserActivity(SignalRUnionMessage): ...
@@ -100,12 +100,9 @@ UserActivity = (
class UserPresence(BaseModel): class UserPresence(BaseModel):
activity: Annotated[ activity: UserActivity | None = None
UserActivity | None, Field(default=None), SignalRMeta(use_upper_case=True)
] status: OnlineStatus | None = None
status: Annotated[
OnlineStatus | None, Field(default=None), SignalRMeta(use_upper_case=True)
]
@property @property
def pushable(self) -> bool: def pushable(self) -> bool:

View File

@@ -53,10 +53,14 @@ class MultiplayerRoomSettings(BaseModel):
auto_start_duration: timedelta = timedelta(seconds=0) auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False auto_skip: bool = False
@property
def auto_start_enabled(self) -> bool:
return self.auto_start_duration != timedelta(seconds=0)
class BeatmapAvailability(BaseModel): class BeatmapAvailability(BaseModel):
state: DownloadState = DownloadState.UNKNOWN state: DownloadState = DownloadState.UNKNOWN
progress: float | None = None download_progress: float | None = None
class _MatchUserState(SignalRUnionMessage): ... class _MatchUserState(SignalRUnionMessage): ...
@@ -283,10 +287,12 @@ class PlaylistItem(BaseModel):
return copy return copy
class _MultiplayerCountdown(BaseModel): class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0 id: int = 0
remaining: timedelta time_remaining: timedelta
is_exclusive: bool = False is_exclusive: Annotated[
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
class MatchStartCountdown(_MultiplayerCountdown): class MatchStartCountdown(_MultiplayerCountdown):
@@ -310,7 +316,7 @@ class MultiplayerRoomUser(BaseModel):
user_id: int user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability( availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, progress=None state=DownloadState.UNKNOWN, download_progress=None
) )
mods: list[APIMod] = Field(default_factory=list) mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None match_state: MatchUserState | None = None
@@ -602,8 +608,8 @@ class CountdownInfo:
def __init__(self, countdown: MultiplayerCountdown): def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown self.countdown = countdown
self.duration = ( self.duration = (
countdown.remaining countdown.time_remaining
if countdown.remaining > timedelta(seconds=0) if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0) else timedelta(seconds=0)
) )
@@ -776,13 +782,12 @@ class ServerMultiplayerRoom:
): ):
async def _countdown_task(self: "ServerMultiplayerRoom"): async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds()) await asyncio.sleep(info.duration.total_seconds())
await self.stop_countdown(countdown)
if on_complete is not None: if on_complete is not None:
await on_complete(self) await on_complete(self)
await self.stop_countdown(countdown)
if countdown.is_exclusive: if countdown.is_exclusive:
await self.stop_all_countdowns() await self.stop_all_countdowns()
countdown.id = await self.get_next_countdown_id() countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown) info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown) self.room.active_countdowns.append(info.countdown)
@@ -793,21 +798,14 @@ class ServerMultiplayerRoom:
info.task = asyncio.create_task(_countdown_task(self)) info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown): async def stop_countdown(self, countdown: MultiplayerCountdown):
info = next( info = self._tracked_countdown.get(countdown.id)
(
info
for info in self._tracked_countdown.values()
if info.countdown.id == countdown.id
),
None,
)
if info is None: if info is None:
return return
if info.task is not None and not info.task.done():
info.task.cancel()
del self._tracked_countdown[countdown.id] del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown) self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
if info.task is not None and not info.task.done():
info.task.cancel()
async def stop_all_countdowns(self): async def stop_all_countdowns(self):
for countdown in list(self._tracked_countdown.values()): for countdown in list(self._tracked_countdown.values()):
@@ -817,19 +815,19 @@ class ServerMultiplayerRoom:
self.room.active_countdowns.clear() self.room.active_countdowns.clear()
class _MatchServerEvent(BaseModel): ... class _MatchServerEvent(SignalRUnionMessage): ...
class CountdownStartedEvent(_MatchServerEvent): class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown countdown: MultiplayerCountdown
type: Literal[0] = Field(default=0, exclude=True) union_type: ClassVar[Literal[0]] = 0
class CountdownStoppedEvent(_MatchServerEvent): class CountdownStoppedEvent(_MatchServerEvent):
id: int id: int
type: Literal[1] = Field(default=1, exclude=True) union_type: ClassVar[Literal[1]] = 1
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent

View File

@@ -13,7 +13,6 @@ from pydantic import (
class SignalRMeta: class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_upper_case: bool = False # use upper CamelCase for field names
use_abbr: bool = True use_abbr: bool = True

View File

@@ -19,12 +19,14 @@ from app.models.multiplayer_hub import (
GameplayAbortReason, GameplayAbortReason,
MatchRequest, MatchRequest,
MatchServerEvent, MatchServerEvent,
MatchStartCountdown,
MultiplayerClientState, MultiplayerClientState,
MultiplayerRoom, MultiplayerRoom,
MultiplayerRoomSettings, MultiplayerRoomSettings,
MultiplayerRoomUser, MultiplayerRoomUser,
PlaylistItem, PlaylistItem,
ServerMultiplayerRoom, ServerMultiplayerRoom,
ServerShuttingDownCountdown,
StartMatchCountdownRequest, StartMatchCountdownRequest,
StopCountdownRequest, StopCountdownRequest,
) )
@@ -160,7 +162,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
availability = user.availability availability = user.availability
if ( if (
availability.state == beatmap_availability.state availability.state == beatmap_availability.state
and availability.progress == beatmap_availability.progress and availability.download_progress == beatmap_availability.download_progress
): ):
return return
user.availability = beatmap_availability user.availability = beatmap_availability
@@ -512,6 +514,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
async def update_room_state(self, room: ServerMultiplayerRoom): async def update_room_state(self, room: ServerMultiplayerRoom):
match room.room.state: match room.room.state:
case MultiplayerRoomState.OPEN:
if room.room.settings.auto_start_enabled:
if (
not room.queue.current_item.expired
and any(
u.state == MultiplayerUserState.READY
for u in room.room.users
)
and not any(
isinstance(countdown, MatchStartCountdown)
for countdown in room.room.active_countdowns
)
):
await room.start_countdown(
MatchStartCountdown(
time_remaining=room.room.settings.auto_start_duration
),
self.start_match,
)
case MultiplayerRoomState.WAITING_FOR_LOAD: case MultiplayerRoomState.WAITING_FOR_LOAD:
played_count = len( played_count = len(
[True for user in room.room.users if user.state.is_playing] [True for user in room.room.users if user.state.is_playing]
@@ -610,7 +631,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
) )
await room.start_countdown( await room.start_countdown(
ForceGameplayStartCountdown( ForceGameplayStartCountdown(
remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)
), ),
self.start_gameplay, self.start_gameplay,
) )
@@ -885,15 +906,34 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
raise InvokeException("You are not in this room") raise InvokeException("You are not in this room")
if isinstance(request, StartMatchCountdownRequest): if isinstance(request, StartMatchCountdownRequest):
# TODO: countdown if room.host and room.host.user_id != user.user_id:
... raise InvokeException("You are not the host of this room")
if room.state != MultiplayerRoomState.OPEN:
raise InvokeException("Cannot start a countdown during ongoing play")
await server_room.start_countdown(
MatchStartCountdown(time_remaining=request.duration),
self.start_match,
)
elif isinstance(request, StopCountdownRequest): elif isinstance(request, StopCountdownRequest):
... countdown = next(
(c for c in room.active_countdowns if c.id == request.id),
None,
)
if countdown is None:
return
if (
isinstance(countdown, MatchStartCountdown)
and room.settings.auto_start_enabled
) or isinstance(
countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown)
):
raise InvokeException("Cannot stop the requested countdown")
await server_room.stop_countdown(countdown)
else: else:
await server_room.match_type_handler.handle_request(user, request) await server_room.match_type_handler.handle_request(user, request)
async def InvitePlayer(self, client: Client, user_id: int): async def InvitePlayer(self, client: Client, user_id: int):
print(f"Inviting player... {client.user_id} {user_id}")
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
if store.room_id == 0: if store.room_id == 0:
raise InvokeException("You are not in a room") raise InvokeException("You are not in a room")

View File

@@ -15,7 +15,7 @@ from typing import (
) )
from app.models.signalr import SignalRMeta, SignalRUnionMessage from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
import msgpack_lazer_api as m import msgpack_lazer_api as m
from pydantic import BaseModel from pydantic import BaseModel
@@ -98,7 +98,7 @@ class MsgpackProtocol:
elif issubclass(typ, datetime.datetime): elif issubclass(typ, datetime.datetime):
return [v, 0] return [v, 0]
elif issubclass(typ, datetime.timedelta): elif issubclass(typ, datetime.timedelta):
return int(v.total_seconds()) return int(v.total_seconds() * 10_000_000)
elif isinstance(v, dict): elif isinstance(v, dict):
return { return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value) cls.serialize_msgpack(k): cls.serialize_msgpack(value)
@@ -216,8 +216,8 @@ class MsgpackProtocol:
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0] return v[0]
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta): elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
return datetime.timedelta(seconds=int(v)) return datetime.timedelta(seconds=int(v / 10_000_000))
elif isinstance(v, list): elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v] return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum): elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ) list_ = list(typ)
@@ -300,10 +300,10 @@ class MsgpackProtocol:
class JSONProtocol: class JSONProtocol:
@classmethod @classmethod
def serialize_to_json(cls, v: Any, dict_key: bool = False): def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
typ = v.__class__ typ = v.__class__
if issubclass(typ, BaseModel): if issubclass(typ, BaseModel):
return cls.serialize_model(v) return cls.serialize_model(v, in_union)
elif isinstance(v, dict): elif isinstance(v, dict):
return { return {
cls.serialize_to_json(k, True): cls.serialize_to_json(value) cls.serialize_to_json(k, True): cls.serialize_to_json(value)
@@ -327,22 +327,28 @@ class JSONProtocol:
return v return v
@classmethod @classmethod
def serialize_model(cls, v: BaseModel) -> dict[str, Any]: def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
d = {} d = {}
is_union = issubclass(v.__class__, SignalRUnionMessage)
for field, info in v.__class__.model_fields.items(): for field, info in v.__class__.model_fields.items():
metadata = next( metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None (m for m in info.metadata if isinstance(m, SignalRMeta)), None
) )
if metadata and metadata.json_ignore: if metadata and metadata.json_ignore:
continue continue
d[ name = (
snake_to_camel( snake_to_camel(
field, field,
metadata.use_upper_case if metadata else False,
metadata.use_abbr if metadata else True, metadata.use_abbr if metadata else True,
) )
] = cls.serialize_to_json(getattr(v, field)) if not is_union
if issubclass(v.__class__, SignalRUnionMessage): else snake_to_pascal(
field,
metadata.use_abbr if metadata else True,
)
)
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
if is_union and not in_union:
return { return {
"$dtype": v.__class__.__name__, "$dtype": v.__class__.__name__,
"$value": d, "$value": d,
@@ -360,11 +366,12 @@ class JSONProtocol:
) )
if metadata and metadata.json_ignore: if metadata and metadata.json_ignore:
continue continue
value = v.get( name = (
snake_to_camel( snake_to_camel(field, metadata.use_abbr if metadata else True)
field, not from_union, metadata.use_abbr if metadata else True if not from_union
) else snake_to_pascal(field, metadata.use_abbr if metadata else True)
) )
value = v.get(name)
anno = typ.model_fields[field].annotation anno = typ.model_fields[field].annotation
if anno is None: if anno is None:
d[field] = value d[field] = value
@@ -433,7 +440,7 @@ class JSONProtocol:
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
elif len(parts) == 1: elif len(parts) == 1:
return datetime.timedelta(seconds=int(parts[0])) return datetime.timedelta(seconds=int(parts[0]))
elif isinstance(v, list): elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v] return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum): elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ) list_ = list(typ)

View File

@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
return "".join(result) return "".join(result)
def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str: def snake_to_camel(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to camelCase.""" """Convert a snake_case string to camelCase."""
if not name: if not name:
return name return name
@@ -50,9 +50,43 @@ def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) ->
if part.lower() in abbreviations and use_abbr: if part.lower() in abbreviations and use_abbr:
result.append(part.upper()) result.append(part.upper())
else: else:
if result or not lower_case: if result:
result.append(part.capitalize()) result.append(part.capitalize())
else: else:
result.append(part.lower()) result.append(part.lower())
return "".join(result) return "".join(result)
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to PascalCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
result = []
for part in parts:
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
result.append(part.capitalize())
return "".join(result)