feat(multiplayer): support countdown

This commit is contained in:
MingxuanGame
2025-08-05 17:21:45 +00:00
parent 0988f1fc0c
commit 0a80c5051c
6 changed files with 131 additions and 56 deletions

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from enum import IntEnum
from typing import Annotated, ClassVar, Literal
from typing import ClassVar, Literal
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
from app.models.signalr import SignalRUnionMessage, UserState
from pydantic import BaseModel, Field
from pydantic import BaseModel
class _UserActivity(SignalRUnionMessage): ...
@@ -100,12 +100,9 @@ UserActivity = (
class UserPresence(BaseModel):
activity: Annotated[
UserActivity | None, Field(default=None), SignalRMeta(use_upper_case=True)
]
status: Annotated[
OnlineStatus | None, Field(default=None), SignalRMeta(use_upper_case=True)
]
activity: UserActivity | None = None
status: OnlineStatus | None = None
@property
def pushable(self) -> bool:

View File

@@ -53,10 +53,14 @@ class MultiplayerRoomSettings(BaseModel):
auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False
@property
def auto_start_enabled(self) -> bool:
return self.auto_start_duration != timedelta(seconds=0)
class BeatmapAvailability(BaseModel):
state: DownloadState = DownloadState.UNKNOWN
progress: float | None = None
download_progress: float | None = None
class _MatchUserState(SignalRUnionMessage): ...
@@ -283,10 +287,12 @@ class PlaylistItem(BaseModel):
return copy
class _MultiplayerCountdown(BaseModel):
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
remaining: timedelta
is_exclusive: bool = False
time_remaining: timedelta
is_exclusive: Annotated[
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
class MatchStartCountdown(_MultiplayerCountdown):
@@ -310,7 +316,7 @@ class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, progress=None
state=DownloadState.UNKNOWN, download_progress=None
)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
@@ -602,8 +608,8 @@ class CountdownInfo:
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.remaining
if countdown.remaining > timedelta(seconds=0)
countdown.time_remaining
if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0)
)
@@ -776,13 +782,12 @@ class ServerMultiplayerRoom:
):
async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds())
await self.stop_countdown(countdown)
if on_complete is not None:
await on_complete(self)
await self.stop_countdown(countdown)
if countdown.is_exclusive:
await self.stop_all_countdowns()
countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
@@ -793,21 +798,14 @@ class ServerMultiplayerRoom:
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):
info = next(
(
info
for info in self._tracked_countdown.values()
if info.countdown.id == countdown.id
),
None,
)
info = self._tracked_countdown.get(countdown.id)
if info is None:
return
if info.task is not None and not info.task.done():
info.task.cancel()
del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
if info.task is not None and not info.task.done():
info.task.cancel()
async def stop_all_countdowns(self):
for countdown in list(self._tracked_countdown.values()):
@@ -817,19 +815,19 @@ class ServerMultiplayerRoom:
self.room.active_countdowns.clear()
class _MatchServerEvent(BaseModel): ...
class _MatchServerEvent(SignalRUnionMessage): ...
class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown
type: Literal[0] = Field(default=0, exclude=True)
union_type: ClassVar[Literal[0]] = 0
class CountdownStoppedEvent(_MatchServerEvent):
id: int
type: Literal[1] = Field(default=1, exclude=True)
union_type: ClassVar[Literal[1]] = 1
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent

View File

@@ -13,7 +13,6 @@ from pydantic import (
class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_upper_case: bool = False # use upper CamelCase for field names
use_abbr: bool = True

View File

@@ -19,12 +19,14 @@ from app.models.multiplayer_hub import (
GameplayAbortReason,
MatchRequest,
MatchServerEvent,
MatchStartCountdown,
MultiplayerClientState,
MultiplayerRoom,
MultiplayerRoomSettings,
MultiplayerRoomUser,
PlaylistItem,
ServerMultiplayerRoom,
ServerShuttingDownCountdown,
StartMatchCountdownRequest,
StopCountdownRequest,
)
@@ -160,7 +162,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
availability = user.availability
if (
availability.state == beatmap_availability.state
and availability.progress == beatmap_availability.progress
and availability.download_progress == beatmap_availability.download_progress
):
return
user.availability = beatmap_availability
@@ -512,6 +514,25 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
async def update_room_state(self, room: ServerMultiplayerRoom):
match room.room.state:
case MultiplayerRoomState.OPEN:
if room.room.settings.auto_start_enabled:
if (
not room.queue.current_item.expired
and any(
u.state == MultiplayerUserState.READY
for u in room.room.users
)
and not any(
isinstance(countdown, MatchStartCountdown)
for countdown in room.room.active_countdowns
)
):
await room.start_countdown(
MatchStartCountdown(
time_remaining=room.room.settings.auto_start_duration
),
self.start_match,
)
case MultiplayerRoomState.WAITING_FOR_LOAD:
played_count = len(
[True for user in room.room.users if user.state.is_playing]
@@ -610,7 +631,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
)
await room.start_countdown(
ForceGameplayStartCountdown(
remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)
time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)
),
self.start_gameplay,
)
@@ -885,15 +906,34 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
raise InvokeException("You are not in this room")
if isinstance(request, StartMatchCountdownRequest):
# TODO: countdown
...
if room.host and room.host.user_id != user.user_id:
raise InvokeException("You are not the host of this room")
if room.state != MultiplayerRoomState.OPEN:
raise InvokeException("Cannot start a countdown during ongoing play")
await server_room.start_countdown(
MatchStartCountdown(time_remaining=request.duration),
self.start_match,
)
elif isinstance(request, StopCountdownRequest):
...
countdown = next(
(c for c in room.active_countdowns if c.id == request.id),
None,
)
if countdown is None:
return
if (
isinstance(countdown, MatchStartCountdown)
and room.settings.auto_start_enabled
) or isinstance(
countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown)
):
raise InvokeException("Cannot stop the requested countdown")
await server_room.stop_countdown(countdown)
else:
await server_room.match_type_handler.handle_request(user, request)
async def InvitePlayer(self, client: Client, user_id: int):
print(f"Inviting player... {client.user_id} {user_id}")
store = self.get_or_create_state(client)
if store.room_id == 0:
raise InvokeException("You are not in a room")

View File

@@ -15,7 +15,7 @@ from typing import (
)
from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
import msgpack_lazer_api as m
from pydantic import BaseModel
@@ -98,7 +98,7 @@ class MsgpackProtocol:
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, datetime.timedelta):
return int(v.total_seconds())
return int(v.total_seconds() * 10_000_000)
elif isinstance(v, dict):
return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
@@ -216,8 +216,8 @@ class MsgpackProtocol:
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
return datetime.timedelta(seconds=int(v))
elif isinstance(v, list):
return datetime.timedelta(seconds=int(v / 10_000_000))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
@@ -300,10 +300,10 @@ class MsgpackProtocol:
class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any, dict_key: bool = False):
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v)
return cls.serialize_model(v, in_union)
elif isinstance(v, dict):
return {
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
@@ -327,22 +327,28 @@ class JSONProtocol:
return v
@classmethod
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
d = {}
is_union = issubclass(v.__class__, SignalRUnionMessage)
for field, info in v.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
d[
name = (
snake_to_camel(
field,
metadata.use_upper_case if metadata else False,
metadata.use_abbr if metadata else True,
)
] = cls.serialize_to_json(getattr(v, field))
if issubclass(v.__class__, SignalRUnionMessage):
if not is_union
else snake_to_pascal(
field,
metadata.use_abbr if metadata else True,
)
)
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
if is_union and not in_union:
return {
"$dtype": v.__class__.__name__,
"$value": d,
@@ -360,11 +366,12 @@ class JSONProtocol:
)
if metadata and metadata.json_ignore:
continue
value = v.get(
snake_to_camel(
field, not from_union, metadata.use_abbr if metadata else True
)
name = (
snake_to_camel(field, metadata.use_abbr if metadata else True)
if not from_union
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
)
value = v.get(name)
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
@@ -433,7 +440,7 @@ class JSONProtocol:
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
elif len(parts) == 1:
return datetime.timedelta(seconds=int(parts[0]))
elif isinstance(v, list):
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)

View File

@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
return "".join(result)
def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str:
def snake_to_camel(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to camelCase."""
if not name:
return name
@@ -50,9 +50,43 @@ def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) ->
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
if result or not lower_case:
if result:
result.append(part.capitalize())
else:
result.append(part.lower())
return "".join(result)
def snake_to_pascal(name: str, use_abbr: bool = True) -> str:
"""Convert a snake_case string to PascalCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
result = []
for part in parts:
if part.lower() in abbreviations and use_abbr:
result.append(part.upper())
else:
result.append(part.capitalize())
return "".join(result)