feat(multiplayer): support play

WIP
This commit is contained in:
MingxuanGame
2025-08-03 12:53:22 +00:00
parent b7bc87b8b6
commit 2600fa499f
11 changed files with 666 additions and 196 deletions

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.models.mods import APIMod, msgpack_to_apimod from app.models.mods import APIMod
from app.models.multiplayer_hub import PlaylistItem from app.models.multiplayer_hub import PlaylistItem
from .beatmap import Beatmap, BeatmapResp from .beatmap import Beatmap, BeatmapResp
@@ -79,10 +79,10 @@ class Playlist(PlaylistBase, table=True):
owner_id=playlist.owner_id, owner_id=playlist.owner_id,
ruleset_id=playlist.ruleset_id, ruleset_id=playlist.ruleset_id,
beatmap_id=playlist.beatmap_id, beatmap_id=playlist.beatmap_id,
required_mods=[msgpack_to_apimod(mod) for mod in playlist.required_mods], required_mods=playlist.required_mods,
allowed_mods=[msgpack_to_apimod(mod) for mod in playlist.allowed_mods], allowed_mods=playlist.allowed_mods,
expired=playlist.expired, expired=playlist.expired,
playlist_order=playlist.order, playlist_order=playlist.playlist_order,
played_at=playlist.played_at, played_at=playlist.played_at,
freestyle=playlist.freestyle, freestyle=playlist.freestyle,
room_id=room_id, room_id=room_id,
@@ -99,14 +99,10 @@ class Playlist(PlaylistBase, table=True):
db_playlist.owner_id = playlist.owner_id db_playlist.owner_id = playlist.owner_id
db_playlist.ruleset_id = playlist.ruleset_id db_playlist.ruleset_id = playlist.ruleset_id
db_playlist.beatmap_id = playlist.beatmap_id db_playlist.beatmap_id = playlist.beatmap_id
db_playlist.required_mods = [ db_playlist.required_mods = playlist.required_mods
msgpack_to_apimod(mod) for mod in playlist.required_mods db_playlist.allowed_mods = playlist.allowed_mods
]
db_playlist.allowed_mods = [
msgpack_to_apimod(mod) for mod in playlist.allowed_mods
]
db_playlist.expired = playlist.expired db_playlist.expired = playlist.expired
db_playlist.playlist_order = playlist.order db_playlist.playlist_order = playlist.playlist_order
db_playlist.played_at = playlist.played_at db_playlist.played_at = playlist.played_at
db_playlist.freestyle = playlist.freestyle db_playlist.freestyle = playlist.freestyle
await session.commit() await session.commit()

View File

@@ -125,7 +125,7 @@ class RoomResp(RoomBase):
type=room.settings.match_type, type=room.settings.match_type,
queue_mode=room.settings.queue_mode, queue_mode=room.settings.queue_mode,
auto_skip=room.settings.auto_skip, auto_skip=room.settings.auto_skip,
auto_start_duration=room.settings.auto_start_duration, auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
status=server_room.status, status=server_room.status,
category=server_room.category, category=server_room.category,
# duration = room.settings.duration, # duration = room.settings.duration,

View File

@@ -91,7 +91,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# optional # optional
# TODO: current_user_attributes # TODO: current_user_attributes
position: int | None = Field(default=None) # multiplayer # position: int | None = Field(default=None) # multiplayer
class Score(ScoreBase, table=True): class Score(ScoreBase, table=True):
@@ -162,6 +162,7 @@ class ScoreResp(ScoreBase):
maximum_statistics: ScoreStatistics | None = None maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None rank_global: int | None = None
rank_country: int | None = None rank_country: int | None = None
position: int = 1 # TODO
@classmethod @classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
@@ -618,6 +619,8 @@ async def process_score(
fetcher: "Fetcher", fetcher: "Fetcher",
session: AsyncSession, session: AsyncSession,
redis: Redis, redis: Redis,
item_id: int | None = None,
room_id: int | None = None,
) -> Score: ) -> Score:
assert user.id assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
@@ -649,6 +652,8 @@ async def process_score(
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0), nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
playlist_item_id=item_id,
room_id=room_id,
) )
if can_get_pp: if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)

View File

@@ -5,8 +5,6 @@ from typing import Literal, NotRequired, TypedDict
from app.path import STATIC_DIR from app.path import STATIC_DIR
from msgpack_lazer_api import APIMod as MsgpackAPIMod
class APIMod(TypedDict): class APIMod(TypedDict):
acronym: str acronym: str
@@ -169,13 +167,3 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
if expected_value != NO_CHECK and value != expected_value: if expected_value != NO_CHECK and value != expected_value:
return False return False
return True return True
def msgpack_to_apimod(mod: MsgpackAPIMod) -> APIMod:
"""
Convert a MsgpackAPIMod to an APIMod.
"""
return APIMod(
acronym=mod.acronym,
settings=mod.settings,
)

View File

@@ -1,13 +1,16 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass import asyncio
import datetime from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Annotated, Any, Literal from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
from app.dependencies.database import engine from app.dependencies.database import engine
from app.exception import InvokeException from app.exception import InvokeException
from .mods import APIMod
from .room import ( from .room import (
DownloadState, DownloadState,
MatchType, MatchType,
@@ -18,15 +21,14 @@ from .room import (
RoomStatus, RoomStatus,
) )
from .signalr import ( from .signalr import (
EnumByIndex, SignalRMeta,
MessagePackArrayModel, SignalRUnionMessage,
UserState, UserState,
msgpack_union,
msgpack_union_dump,
) )
from msgpack_lazer_api import APIMod from pydantic import BaseModel, Field
from pydantic import Field, field_serializer, field_validator from sqlalchemy import update
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -40,37 +42,37 @@ class MultiplayerClientState(UserState):
room_id: int = 0 room_id: int = 0
class MultiplayerRoomSettings(MessagePackArrayModel): class MultiplayerRoomSettings(BaseModel):
name: str = "Unnamed Room" name: str = "Unnamed Room"
playlist_item_id: int = 0 playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
password: str = "" password: str = ""
match_type: Annotated[MatchType, EnumByIndex(MatchType)] = MatchType.HEAD_TO_HEAD match_type: MatchType = MatchType.HEAD_TO_HEAD
queue_mode: Annotated[QueueMode, EnumByIndex(QueueMode)] = QueueMode.HOST_ONLY queue_mode: QueueMode = QueueMode.HOST_ONLY
auto_start_duration: int = 0 auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False auto_skip: bool = False
class BeatmapAvailability(MessagePackArrayModel): class BeatmapAvailability(BaseModel):
state: Annotated[DownloadState, EnumByIndex(DownloadState)] = DownloadState.UNKNOWN state: DownloadState = DownloadState.UNKNOWN
progress: float | None = None progress: float | None = None
class _MatchUserState(MessagePackArrayModel): ... class _MatchUserState(SignalRUnionMessage): ...
class TeamVersusUserState(_MatchUserState): class TeamVersusUserState(_MatchUserState):
team_id: int team_id: int
type: Literal[0] = Field(0, exclude=True) union_type: ClassVar[Literal[0]] = 0
MatchUserState = TeamVersusUserState MatchUserState = TeamVersusUserState
class _MatchRoomState(MessagePackArrayModel): ... class _MatchRoomState(SignalRUnionMessage): ...
class MultiplayerTeam(MessagePackArrayModel): class MultiplayerTeam(BaseModel):
id: int id: int
name: str name: str
@@ -83,24 +85,24 @@ class TeamVersusRoomState(_MatchRoomState):
] ]
) )
type: Literal[0] = Field(0, exclude=True) union_type: ClassVar[Literal[0]] = 0
MatchRoomState = TeamVersusRoomState MatchRoomState = TeamVersusRoomState
class PlaylistItem(MessagePackArrayModel): class PlaylistItem(BaseModel):
id: int id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
owner_id: int owner_id: int
beatmap_id: int beatmap_id: int
checksum: str beatmap_checksum: str
ruleset_id: int ruleset_id: int
required_mods: list[APIMod] = Field(default_factory=list) required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list) allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool expired: bool
order: int playlist_order: int
played_at: datetime.datetime | None = None played_at: datetime | None = None
star: float star_rating: float
freestyle: bool freestyle: bool
def validate_user_mods( def validate_user_mods(
@@ -127,7 +129,10 @@ class PlaylistItem(MessagePackArrayModel):
# Check if mods are valid for the ruleset # Check if mods are valid for the ruleset
for mod in proposed_mods: for mod in proposed_mods:
if ruleset_key not in API_MODS or mod.acronym not in API_MODS[ruleset_key]: if (
ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[ruleset_key]
):
all_proposed_valid = False all_proposed_valid = False
continue continue
valid_mods.append(mod) valid_mods.append(mod)
@@ -136,35 +141,35 @@ class PlaylistItem(MessagePackArrayModel):
incompatible_mods = set() incompatible_mods = set()
final_valid_mods = [] final_valid_mods = []
for mod in valid_mods: for mod in valid_mods:
if mod.acronym in incompatible_mods: if mod["acronym"] in incompatible_mods:
all_proposed_valid = False all_proposed_valid = False
continue continue
setting_mods = API_MODS[ruleset_key].get(mod.acronym) setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
if setting_mods: if setting_mods:
incompatible_mods.update(setting_mods["IncompatibleMods"]) incompatible_mods.update(setting_mods["IncompatibleMods"])
final_valid_mods.append(mod) final_valid_mods.append(mod)
# If not freestyle, check against allowed mods # If not freestyle, check against allowed mods
if not self.freestyle: if not self.freestyle:
allowed_acronyms = {mod.acronym for mod in self.allowed_mods} allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
filtered_valid_mods = [] filtered_valid_mods = []
for mod in final_valid_mods: for mod in final_valid_mods:
if mod.acronym not in allowed_acronyms: if mod["acronym"] not in allowed_acronyms:
all_proposed_valid = False all_proposed_valid = False
else: else:
filtered_valid_mods.append(mod) filtered_valid_mods.append(mod)
final_valid_mods = filtered_valid_mods final_valid_mods = filtered_valid_mods
# Check compatibility with required mods # Check compatibility with required mods
required_mod_acronyms = {mod.acronym for mod in self.required_mods} required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = { all_mod_acronyms = {
mod.acronym for mod in final_valid_mods mod["acronym"] for mod in final_valid_mods
} | required_mod_acronyms } | required_mod_acronyms
# Check for incompatibility between required and user mods # Check for incompatibility between required and user mods
filtered_valid_mods = [] filtered_valid_mods = []
for mod in final_valid_mods: for mod in final_valid_mods:
mod_acronym = mod.acronym mod_acronym = mod["acronym"]
is_compatible = True is_compatible = True
for other_acronym in all_mod_acronyms: for other_acronym in all_mod_acronyms:
@@ -181,23 +186,29 @@ class PlaylistItem(MessagePackArrayModel):
return all_proposed_valid, filtered_valid_mods return all_proposed_valid, filtered_valid_mods
def clone(self) -> "PlaylistItem":
copy = self.model_copy()
copy.required_mods = list(self.required_mods)
copy.allowed_mods = list(self.allowed_mods)
return copy
class _MultiplayerCountdown(MessagePackArrayModel):
id: int class _MultiplayerCountdown(BaseModel):
remaining: int id: int = 0
is_exclusive: bool remaining: timedelta
is_exclusive: bool = False
class MatchStartCountdown(_MultiplayerCountdown): class MatchStartCountdown(_MultiplayerCountdown):
type: Literal[0] = Field(0, exclude=True) union_type: ClassVar[Literal[0]] = 0
class ForceGameplayStartCountdown(_MultiplayerCountdown): class ForceGameplayStartCountdown(_MultiplayerCountdown):
type: Literal[1] = Field(1, exclude=True) union_type: ClassVar[Literal[1]] = 1
class ServerShuttingDownCountdown(_MultiplayerCountdown): class ServerShuttingDownCountdown(_MultiplayerCountdown):
type: Literal[2] = Field(2, exclude=True) union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = ( MultiplayerCountdown = (
@@ -205,11 +216,9 @@ MultiplayerCountdown = (
) )
class MultiplayerRoomUser(MessagePackArrayModel): class MultiplayerRoomUser(BaseModel):
user_id: int user_id: int
state: Annotated[MultiplayerUserState, EnumByIndex(MultiplayerUserState)] = ( state: MultiplayerUserState = MultiplayerUserState.IDLE
MultiplayerUserState.IDLE
)
availability: BeatmapAvailability = BeatmapAvailability( availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, progress=None state=DownloadState.UNKNOWN, progress=None
) )
@@ -218,50 +227,33 @@ class MultiplayerRoomUser(MessagePackArrayModel):
ruleset_id: int | None = None # freestyle ruleset_id: int | None = None # freestyle
beatmap_id: int | None = None # freestyle beatmap_id: int | None = None # freestyle
@field_validator("match_state", mode="before")
def union_validate(v: Any):
if isinstance(v, list):
return msgpack_union(v)
return v
@field_serializer("match_state") class MultiplayerRoom(BaseModel):
def union_serialize(v: Any):
return msgpack_union_dump(v)
class MultiplayerRoom(MessagePackArrayModel):
room_id: int room_id: int
state: Annotated[MultiplayerRoomState, EnumByIndex(MultiplayerRoomState)] state: MultiplayerRoomState
settings: MultiplayerRoomSettings settings: MultiplayerRoomSettings
users: list[MultiplayerRoomUser] = Field(default_factory=list) users: list[MultiplayerRoomUser] = Field(default_factory=list)
host: MultiplayerRoomUser | None = None host: MultiplayerRoomUser | None = None
match_state: MatchRoomState | None = None match_state: MatchRoomState | None = None
playlist: list[PlaylistItem] = Field(default_factory=list) playlist: list[PlaylistItem] = Field(default_factory=list)
active_cooldowns: list[MultiplayerCountdown] = Field(default_factory=list) active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
channel_id: int channel_id: int
@field_validator("match_state", mode="before")
def union_validate(v: Any):
if isinstance(v, list):
return msgpack_union(v)
return v
@field_serializer("match_state")
def union_serialize(v: Any):
return msgpack_union_dump(v)
class MultiplayerQueue: class MultiplayerQueue:
def __init__(self, room: "ServerMultiplayerRoom", hub: "MultiplayerHub"): def __init__(self, room: "ServerMultiplayerRoom"):
self.server_room = room self.server_room = room
self.hub = hub
self.current_index = 0 self.current_index = 0
@property
def hub(self) -> "MultiplayerHub":
return self.server_room.hub
@property @property
def upcoming_items(self): def upcoming_items(self):
return sorted( return sorted(
(item for item in self.room.playlist if not item.expired), (item for item in self.room.playlist if not item.expired),
key=lambda i: i.order, key=lambda i: i.playlist_order,
) )
@property @property
@@ -323,9 +315,9 @@ class MultiplayerQueue:
) )
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
for idx, item in enumerate(ordered_active_items): for idx, item in enumerate(ordered_active_items):
if item.order == idx: if item.playlist_order == idx:
continue continue
item.order = idx item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session) await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed( await self.hub.playlist_changed(
self.server_room, item, beatmap_changed=False self.server_room, item, beatmap_changed=False
@@ -338,7 +330,7 @@ class MultiplayerQueue:
if upcoming_items if upcoming_items
else max( else max(
self.room.playlist, self.room.playlist,
key=lambda i: i.played_at or datetime.datetime.min, key=lambda i: i.played_at or datetime.min,
) )
) )
self.current_index = self.room.playlist.index(next_item) self.current_index = self.room.playlist.index(next_item)
@@ -356,14 +348,7 @@ class MultiplayerQueue:
limit = HOST_LIMIT if is_host else PER_USER_LIMIT limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if ( if (
len( len([True for u in self.room.playlist if u.owner_id == user.user_id])
list(
filter(
lambda x: x.owner_id == user.user_id,
self.room.playlist,
)
)
)
>= limit >= limit
): ):
raise InvokeException(f"You can only have {limit} items in the queue") raise InvokeException(f"You can only have {limit} items in the queue")
@@ -376,11 +361,11 @@ class MultiplayerQueue:
beatmap = await session.get(Beatmap, item.beatmap_id) beatmap = await session.get(Beatmap, item.beatmap_id)
if beatmap is None: if beatmap is None:
raise InvokeException("Beatmap not found") raise InvokeException("Beatmap not found")
if item.checksum != beatmap.checksum: if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch") raise InvokeException("Checksum mismatch")
# TODO: mods validation # TODO: mods validation
item.owner_id = user.user_id item.owner_id = user.user_id
item.star = float( item.star_rating = float(
beatmap.difficulty_rating beatmap.difficulty_rating
) # FIXME: beatmap use decimal ) # FIXME: beatmap use decimal
await Playlist.add_to_db(item, self.room.room_id, session) await Playlist.add_to_db(item, self.room.room_id, session)
@@ -400,7 +385,7 @@ class MultiplayerQueue:
beatmap = await session.get(Beatmap, item.beatmap_id) beatmap = await session.get(Beatmap, item.beatmap_id)
if beatmap is None: if beatmap is None:
raise InvokeException("Beatmap not found") raise InvokeException("Beatmap not found")
if item.checksum != beatmap.checksum: if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch") raise InvokeException("Checksum mismatch")
existing_item = next( existing_item = next(
@@ -423,8 +408,8 @@ class MultiplayerQueue:
# TODO: mods validation # TODO: mods validation
item.owner_id = user.user_id item.owner_id = user.user_id
item.star = float(beatmap.difficulty_rating) item.star_rating = float(beatmap.difficulty_rating)
item.order = existing_item.order item.playlist_order = existing_item.playlist_order
await Playlist.update(item, self.room.room_id, session) await Playlist.update(item, self.room.room_id, session)
@@ -437,7 +422,8 @@ class MultiplayerQueue:
await self.hub.playlist_changed( await self.hub.playlist_changed(
self.server_room, self.server_room,
item, item,
beatmap_changed=item.checksum != existing_item.checksum, beatmap_changed=item.beatmap_checksum
!= existing_item.beatmap_checksum,
) )
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
@@ -477,12 +463,46 @@ class MultiplayerQueue:
await self.update_current_item() await self.update_current_item()
await self.hub.playlist_removed(self.server_room, item.id) await self.hub.playlist_removed(self.server_room, item.id)
async def finish_current_item(self):
from app.database import Playlist
async with AsyncSession(engine) as session:
played_at = datetime.now(UTC)
await session.execute(
update(Playlist)
.where(
col(Playlist.id) == self.current_item.id,
col(Playlist.room_id) == self.room.room_id,
)
.values(expired=True, played_at=played_at)
)
self.room.playlist[self.current_index].expired = True
self.room.playlist[self.current_index].played_at = played_at
await self.hub.playlist_changed(self.server_room, self.current_item, True)
await self.update_order()
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
@property @property
def current_item(self): def current_item(self):
"""Get the current playlist item""" return self.room.playlist[self.current_index]
current_id = self.room.settings.playlist_item_id
return next(
(item for item in self.room.playlist if item.id == current_id), @dataclass
class CountdownInfo:
countdown: MultiplayerCountdown
duration: timedelta
task: asyncio.Task | None = None
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.remaining
if countdown.remaining > timedelta(seconds=0)
else timedelta(seconds=0)
) )
@@ -491,5 +511,79 @@ class ServerMultiplayerRoom:
room: MultiplayerRoom room: MultiplayerRoom
category: RoomCategory category: RoomCategory
status: RoomStatus status: RoomStatus
start_at: datetime.datetime start_at: datetime
hub: "MultiplayerHub"
queue: MultiplayerQueue | None = None queue: MultiplayerQueue | None = None
_next_countdown_id: int = 0
_countdown_id_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_tracked_countdown: dict[int, CountdownInfo] = field(default_factory=dict)
async def get_next_countdown_id(self) -> int:
async with self._countdown_id_lock:
self._next_countdown_id += 1
return self._next_countdown_id
async def start_countdown(
self,
countdown: MultiplayerCountdown,
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
):
async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds())
await self.stop_countdown(countdown)
if on_complete is not None:
await on_complete(self)
if countdown.is_exclusive:
await self.stop_all_countdowns()
countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(
self, CountdownStartedEvent(countdown=info.countdown)
)
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):
info = next(
(
info
for info in self._tracked_countdown.values()
if info.countdown.id == countdown.id
),
None,
)
if info is None:
return
if info.task is not None and not info.task.done():
info.task.cancel()
del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
async def stop_all_countdowns(self):
for countdown in list(self._tracked_countdown.values()):
await self.stop_countdown(countdown.countdown)
self._tracked_countdown.clear()
self.room.active_countdowns.clear()
class _MatchServerEvent(BaseModel): ...
class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown
type: Literal[0] = Field(default=0, exclude=True)
class CountdownStoppedEvent(_MatchServerEvent):
id: int
type: Literal[1] = Field(default=1, exclude=True)
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent

View File

@@ -53,6 +53,15 @@ class MultiplayerUserState(str, Enum):
RESULTS = "results" RESULTS = "results"
SPECTATING = "spectating" SPECTATING = "spectating"
@property
def is_playing(self) -> bool:
return self in {
self.WAITING_FOR_LOAD,
self.PLAYING,
self.READY_FOR_GAMEPLAY,
self.LOADED,
}
class DownloadState(str, Enum): class DownloadState(str, Enum):
UNKNOWN = "unknown" UNKNOWN = "unknown"

View File

@@ -14,6 +14,7 @@ class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_upper_case: bool = False # use upper CamelCase for field names use_upper_case: bool = False # use upper CamelCase for field names
use_abbr: bool = True
class SignalRUnionMessage(BaseModel): class SignalRUnionMessage(BaseModel):

View File

@@ -1,10 +1,19 @@
from __future__ import annotations from __future__ import annotations
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User from app.database import (
Beatmap,
Playlist,
Score,
ScoreResp,
ScoreToken,
ScoreTokenResp,
User,
)
from app.database.score import get_leaderboard, process_score, process_user from app.database.score import get_leaderboard, process_score, process_user
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.score import ( from app.models.score import (
INT_TO_MODE, INT_TO_MODE,
@@ -13,6 +22,7 @@ from app.models.score import (
Rank, Rank,
SoloScoreSubmissionInfo, SoloScoreSubmissionInfo,
) )
from app.signalr.hub import MultiplayerHubs
from .api_router import router from .api_router import router
@@ -24,6 +34,68 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
async def submit_score(
info: SoloScoreSubmissionInfo,
beatmap: int,
token: int,
current_user: User,
db: AsyncSession,
redis: Redis,
fetcher: Fetcher,
item_id: int | None = None,
room_id: int | None = None,
):
if not info.passed:
info.rank = Rank.F
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
).first()
if not score_token or score_token.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Score token not found")
if score_token.score_id:
score = (
await db.exec(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap))
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
score = await process_score(
current_user,
beatmap,
ranked,
score_token,
info,
fetcher,
db,
redis,
)
await db.refresh(current_user)
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score)
class BeatmapScores(BaseModel): class BeatmapScores(BaseModel):
scores: list[ScoreResp] scores: list[ScoreResp]
userScore: ScoreResp | None = None userScore: ScoreResp | None = None
@@ -97,9 +169,10 @@ async def get_user_beatmap_score(
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap" status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
) )
else: else:
resp = await ScoreResp.from_db(db, user_score)
return BeatmapUserScore( return BeatmapUserScore(
position=user_score.position if user_score.position is not None else 0, position=resp.rank_global or 0,
score=await ScoreResp.from_db(db, user_score), score=resp,
) )
@@ -173,55 +246,95 @@ async def submit_solo_score(
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher), fetcher=Depends(get_fetcher),
): ):
if not info.passed: return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
info.rank = Rank.F
async with db:
score_token = ( @router.post(
await db.exec( "/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
select(ScoreToken) )
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] async def create_playlist_score(
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id) room_id: int,
playlist_id: int,
beatmap_id: int = Form(),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
version_hash: str = Form(""),
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
room = MultiplayerHubs.rooms[room_id]
if not room:
raise HTTPException(status_code=404, detail="Room not found")
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
) )
).first() )
if not score_token or score_token.user_id != current_user.id: ).first()
raise HTTPException(status_code=404, detail="Score token not found") if not item:
if score_token.score_id: raise HTTPException(status_code=404, detail="Playlist not found")
score = (
await db.exec( # validate
select(Score).where( if not item.freestyle:
Score.id == score_token.score_id, if item.ruleset_id != ruleset_id:
Score.user_id == current_user.id, raise HTTPException(
) status_code=400, detail="Ruleset mismatch in playlist item"
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
beatmap_status = (
await db.exec(
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
)
).first()
if beatmap_status is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
ranked = beatmap_status in {
BeatmapRankStatus.RANKED,
BeatmapRankStatus.APPROVED,
}
score = await process_score(
current_user,
beatmap,
ranked,
score_token,
info,
fetcher,
db,
redis,
) )
await db.refresh(current_user) if item.beatmap_id != beatmap_id:
score_id = score.id raise HTTPException(
score_token.score_id = score_id status_code=400, detail="Beatmap ID mismatch in playlist item"
await process_user(db, current_user, score, ranked) )
score = (await db.exec(select(Score).where(Score.id == score_id))).first() # TODO: max attempts
assert score is not None if item.expired:
return await ScoreResp.from_db(db, score) raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at:
raise HTTPException(
status_code=400, detail="Playlist item has already been played"
)
# 这里应该不用验证mod了吧。。。
score_token = ScoreToken(
user_id=current_user.id,
beatmap_id=beatmap_id,
ruleset_id=INT_TO_MODE[ruleset_id],
playlist_item_id=playlist_id,
)
session.add(score_token)
await session.commit()
await session.refresh(score_token)
return ScoreTokenResp.from_db(score_token)
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
async def submit_playlist_score(
room_id: int,
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist item not found")
score_resp = await submit_score(
info,
item.beatmap_id,
token,
current_user,
session,
redis,
fetcher,
item.id,
room_id,
)
return score_resp

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from datetime import timedelta
from typing import override from typing import override
from app.database import Room from app.database import Room
@@ -8,8 +10,11 @@ from app.database.playlists import Playlist
from app.dependencies.database import engine from app.dependencies.database import engine
from app.exception import InvokeException from app.exception import InvokeException
from app.log import logger from app.log import logger
from app.models.mods import APIMod
from app.models.multiplayer_hub import ( from app.models.multiplayer_hub import (
BeatmapAvailability, BeatmapAvailability,
ForceGameplayStartCountdown,
MatchServerEvent,
MultiplayerClientState, MultiplayerClientState,
MultiplayerQueue, MultiplayerQueue,
MultiplayerRoom, MultiplayerRoom,
@@ -17,16 +22,22 @@ from app.models.multiplayer_hub import (
PlaylistItem, PlaylistItem,
ServerMultiplayerRoom, ServerMultiplayerRoom,
) )
from app.models.room import RoomCategory, RoomStatus from app.models.room import (
DownloadState,
MultiplayerRoomState,
MultiplayerUserState,
RoomCategory,
RoomStatus,
)
from app.models.score import GameMode from app.models.score import GameMode
from app.models.signalr import serialize_to_list
from .hub import Client, Hub from .hub import Client, Hub
from msgpack_lazer_api import APIMod
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
GAMEPLAY_LOAD_TIMEOUT = 30
class MultiplayerHub(Hub[MultiplayerClientState]): class MultiplayerHub(Hub[MultiplayerClientState]):
@override @override
@@ -58,7 +69,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
type=room.settings.match_type, type=room.settings.match_type,
queue_mode=room.settings.queue_mode, queue_mode=room.settings.queue_mode,
auto_skip=room.settings.auto_skip, auto_skip=room.settings.auto_skip,
auto_start_duration=room.settings.auto_start_duration, auto_start_duration=int(
room.settings.auto_start_duration.total_seconds()
),
host_id=client.user_id, host_id=client.user_id,
status=RoomStatus.IDLE, status=RoomStatus.IDLE,
) )
@@ -75,10 +88,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
category=RoomCategory.NORMAL, category=RoomCategory.NORMAL,
status=RoomStatus.IDLE, status=RoomStatus.IDLE,
start_at=starts_at, start_at=starts_at,
hub=self,
) )
queue = MultiplayerQueue( queue = MultiplayerQueue(
room=server_room, room=server_room,
hub=self,
) )
server_room.queue = queue server_room.queue = queue
self.rooms[room.room_id] = server_room self.rooms[room.room_id] = server_room
@@ -86,6 +99,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
client, room.room_id, room.settings.password client, room.room_id, room.settings.password
) )
async def JoinRoom(self, client: Client, room_id: int):
return self.JoinRoomWithPassword(client, room_id, "")
async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str): async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str):
logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}") logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}")
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
@@ -105,12 +121,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
# from CreateRoom # from CreateRoom
room.host = user room.host = user
store.room_id = room_id store.room_id = room_id
await self.broadcast_group_call( await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user)
self.group_id(room_id), "UserJoined", serialize_to_list(user)
)
room.users.append(user) room.users.append(user)
self.add_to_group(client, self.group_id(room_id)) self.add_to_group(client, self.group_id(room_id))
return serialize_to_list(room) return room
async def ChangeBeatmapAvailability( async def ChangeBeatmapAvailability(
self, client: Client, beatmap_availability: BeatmapAvailability self, client: Client, beatmap_availability: BeatmapAvailability
@@ -132,12 +146,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
and availability.progress == beatmap_availability.progress and availability.progress == beatmap_availability.progress
): ):
return return
user.availability = availability user.availability = beatmap_availability
await self.broadcast_group_call( await self.broadcast_group_call(
self.group_id(store.room_id), self.group_id(store.room_id),
"UserBeatmapAvailabilityChanged", "UserBeatmapAvailabilityChanged",
user.user_id, user.user_id,
serialize_to_list(beatmap_availability), (beatmap_availability),
) )
async def AddPlaylistItem(self, client: Client, item: PlaylistItem): async def AddPlaylistItem(self, client: Client, item: PlaylistItem):
@@ -198,14 +212,14 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await self.broadcast_group_call( await self.broadcast_group_call(
self.group_id(room.room.room_id), self.group_id(room.room.room_id),
"SettingsChanged", "SettingsChanged",
serialize_to_list(room.room.settings), (room.room.settings),
) )
async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem):
await self.broadcast_group_call( await self.broadcast_group_call(
self.group_id(room.room.room_id), self.group_id(room.room.room_id),
"PlaylistItemAdded", "PlaylistItemAdded",
serialize_to_list(item), (item),
) )
async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int):
@@ -221,7 +235,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await self.broadcast_group_call( await self.broadcast_group_call(
self.group_id(room.room.room_id), self.group_id(room.room.room_id),
"PlaylistItemChanged", "PlaylistItemChanged",
serialize_to_list(item), (item),
) )
async def ChangeUserStyle( async def ChangeUserStyle(
@@ -378,7 +392,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
) )
if not is_valid: if not is_valid:
incompatible_mods = [ incompatible_mods = [
mod.acronym for mod in new_mods if mod not in valid_mods mod["acronym"] for mod in new_mods if mod not in valid_mods
] ]
raise InvokeException( raise InvokeException(
f"Incompatible mods were selected: {','.join(incompatible_mods)}" f"Incompatible mods were selected: {','.join(incompatible_mods)}"
@@ -395,3 +409,221 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
user.user_id, user.user_id,
valid_mods, valid_mods,
) )
async def validate_user_stare(
self,
room: ServerMultiplayerRoom,
old: MultiplayerUserState,
new: MultiplayerUserState,
):
assert room.queue
match new:
case MultiplayerUserState.IDLE:
if old.is_playing:
raise InvokeException(
"Cannot return to idle without aborting gameplay."
)
case MultiplayerUserState.READY:
if old != MultiplayerUserState.IDLE:
raise InvokeException(f"Cannot change state from {old} to {new}")
if room.queue.current_item.expired:
raise InvokeException(
"Cannot ready up while all items have been played."
)
case MultiplayerUserState.WAITING_FOR_LOAD:
raise InvokeException("Cannot change state from {old} to {new}")
case MultiplayerUserState.LOADED:
if old != MultiplayerUserState.WAITING_FOR_LOAD:
raise InvokeException(f"Cannot change state from {old} to {new}")
case MultiplayerUserState.READY_FOR_GAMEPLAY:
if old != MultiplayerUserState.LOADED:
raise InvokeException(f"Cannot change state from {old} to {new}")
case MultiplayerUserState.PLAYING:
raise InvokeException("State is managed by the server.")
case MultiplayerUserState.FINISHED_PLAY:
if old != MultiplayerUserState.PLAYING:
raise InvokeException(f"Cannot change state from {old} to {new}")
case MultiplayerUserState.RESULTS:
raise InvokeException("Cannot change state from {old} to {new}")
case MultiplayerUserState.SPECTATING:
if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY):
raise InvokeException(f"Cannot change state from {old} to {new}")
async def ChangeState(self, client: Client, state: MultiplayerUserState):
store = self.get_or_create_state(client)
if store.room_id == 0:
raise InvokeException("You are not in a room")
if store.room_id not in self.rooms:
raise InvokeException("Room does not exist")
server_room = self.rooms[store.room_id]
room = server_room.room
user = next((u for u in room.users if u.user_id == client.user_id), None)
if user is None:
raise InvokeException("You are not in this room")
if user.state == state:
return
match state:
case MultiplayerUserState.IDLE:
if user.state.is_playing:
return
case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY:
if not user.state.is_playing:
return
await self.validate_user_stare(
server_room,
user.state,
state,
)
await self.change_user_state(server_room, user, state)
if state == MultiplayerUserState.SPECTATING and (
room.state == MultiplayerRoomState.PLAYING
or room.state == MultiplayerRoomState.WAITING_FOR_LOAD
):
await self.call_noblock(client, "LoadRequested")
await self.update_room_state(server_room)
async def change_user_state(
self,
room: ServerMultiplayerRoom,
user: MultiplayerRoomUser,
state: MultiplayerUserState,
):
user.state = state
await self.broadcast_group_call(
self.group_id(room.room.room_id),
"UserStateChanged",
user.user_id,
user.state,
)
async def update_room_state(self, room: ServerMultiplayerRoom):
match room.room.state:
case MultiplayerRoomState.WAITING_FOR_LOAD:
played_count = len(
[True for user in room.room.users if user.state.is_playing]
)
ready_count = len(
[
True
for user in room.room.users
if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY
]
)
if played_count == ready_count:
await self.start_gameplay(room)
case MultiplayerRoomState.PLAYING:
assert room.queue
if all(
u.state != MultiplayerUserState.PLAYING for u in room.room.users
):
for u in filter(
lambda u: u.state == MultiplayerUserState.FINISHED_PLAY,
room.room.users,
):
await self.change_user_state(
room, u, MultiplayerUserState.RESULTS
)
await self.change_room_state(room, MultiplayerRoomState.OPEN)
await self.broadcast_group_call(
self.group_id(room.room.room_id),
"ResultsReady",
)
await room.queue.finish_current_item()
async def change_room_state(
self, room: ServerMultiplayerRoom, state: MultiplayerRoomState
):
room.room.state = state
await self.broadcast_group_call(
self.group_id(room.room.room_id),
"RoomStateChanged",
state,
)
async def StartMatch(self, client: Client):
store = self.get_or_create_state(client)
if store.room_id == 0:
raise InvokeException("You are not in a room")
if store.room_id not in self.rooms:
raise InvokeException("Room does not exist")
server_room = self.rooms[store.room_id]
room = server_room.room
user = next((u for u in room.users if u.user_id == client.user_id), None)
if user is None:
raise InvokeException("You are not in this room")
if room.host is None or room.host.user_id != client.user_id:
raise InvokeException("You are not the host of this room")
if any(u.state != MultiplayerUserState.READY for u in room.users):
raise InvokeException("Not all users are ready")
await self.start_match(server_room)
async def start_match(self, room: ServerMultiplayerRoom):
assert room.queue
if room.room.state != MultiplayerRoomState.OPEN:
raise InvokeException("Can't start match when already in a running state.")
if room.queue.current_item.expired:
raise InvokeException("Current playlist item is expired")
ready_users = [
u
for u in room.room.users
if u.availability.state == DownloadState.LOCALLY_AVAILABLE
and (
u.state == MultiplayerUserState.READY
or u.state == MultiplayerUserState.IDLE
)
]
await asyncio.gather(
*[
self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD)
for u in ready_users
]
)
await self.change_room_state(
room,
MultiplayerRoomState.WAITING_FOR_LOAD,
)
await self.broadcast_group_call(
self.group_id(room.room.room_id),
"LoadRequested",
)
await room.start_countdown(
ForceGameplayStartCountdown(
remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)
),
self.start_gameplay,
)
async def start_gameplay(self, room: ServerMultiplayerRoom):
assert room.queue
if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD:
raise InvokeException("Room is not ready for gameplay")
if room.queue.current_item.expired:
raise InvokeException("Current playlist item is expired")
playing = False
for user in room.room.users:
client = self.get_client_by_id(str(user.user_id))
if client is None:
continue
if user.state in (
MultiplayerUserState.READY_FOR_GAMEPLAY,
MultiplayerUserState.LOADED,
):
playing = True
await self.change_user_state(room, user, MultiplayerUserState.PLAYING)
await self.call_noblock(client, "GameplayStarted")
await self.change_room_state(
room,
(MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN),
)
async def send_match_event(
self, room: ServerMultiplayerRoom, event: MatchServerEvent
):
await self.broadcast_group_call(
self.group_id(room.room.room_id),
"MatchEvent",
event,
)

View File

@@ -97,6 +97,8 @@ class MsgpackProtocol:
return [cls.serialize_msgpack(item) for item in v] return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime): elif issubclass(typ, datetime.datetime):
return [v, 0] return [v, 0]
elif issubclass(typ, datetime.timedelta):
return int(v.total_seconds())
elif isinstance(v, dict): elif isinstance(v, dict):
return { return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value) cls.serialize_msgpack(k): cls.serialize_msgpack(value)
@@ -213,6 +215,8 @@ class MsgpackProtocol:
return typ.model_validate(obj=cls.process_object(v, typ)) return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0] return v[0]
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
return datetime.timedelta(seconds=int(v))
elif isinstance(v, list): elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v] return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum): elif inspect.isclass(typ) and issubclass(typ, Enum):
@@ -296,21 +300,30 @@ class MsgpackProtocol:
class JSONProtocol: class JSONProtocol:
@classmethod @classmethod
def serialize_to_json(cls, v: Any): def serialize_to_json(cls, v: Any, dict_key: bool = False):
typ = v.__class__ typ = v.__class__
if issubclass(typ, BaseModel): if issubclass(typ, BaseModel):
return cls.serialize_model(v) return cls.serialize_model(v)
elif isinstance(v, dict): elif isinstance(v, dict):
return { return {
cls.serialize_to_json(k): cls.serialize_to_json(value) cls.serialize_to_json(k, True): cls.serialize_to_json(value)
for k, value in v.items() for k, value in v.items()
} }
elif isinstance(v, list): elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v] return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime): elif isinstance(v, datetime.datetime):
return v.isoformat() return v.isoformat()
elif isinstance(v, Enum): elif isinstance(v, datetime.timedelta):
# d.hh:mm:ss
total_seconds = int(v.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{hours:02}:{minutes:02}:{seconds:02}"
elif isinstance(v, Enum) and dict_key:
return v.value return v.value
elif isinstance(v, Enum):
list_ = list(typ)
return list_.index(v)
return v return v
@classmethod @classmethod
@@ -322,9 +335,13 @@ class JSONProtocol:
) )
if metadata and metadata.json_ignore: if metadata and metadata.json_ignore:
continue continue
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = ( d[
cls.serialize_to_json(getattr(v, field)) snake_to_camel(
) field,
metadata.use_upper_case if metadata else False,
metadata.use_abbr if metadata else True,
)
] = cls.serialize_to_json(getattr(v, field))
if issubclass(v.__class__, SignalRUnionMessage): if issubclass(v.__class__, SignalRUnionMessage):
return { return {
"$dtype": v.__class__.__name__, "$dtype": v.__class__.__name__,
@@ -343,7 +360,11 @@ class JSONProtocol:
) )
if metadata and metadata.json_ignore: if metadata and metadata.json_ignore:
continue continue
value = v.get(snake_to_camel(field, not from_union)) value = v.get(
snake_to_camel(
field, not from_union, metadata.use_abbr if metadata else True
)
)
anno = typ.model_fields[field].annotation anno = typ.model_fields[field].annotation
if anno is None: if anno is None:
d[field] = value d[field] = value
@@ -401,6 +422,17 @@ class JSONProtocol:
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v) return datetime.datetime.fromisoformat(v)
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
# d.hh:mm:ss
parts = v.split(":")
if len(parts) == 3:
return datetime.timedelta(
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
)
elif len(parts) == 2:
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
elif len(parts) == 1:
return datetime.timedelta(seconds=int(parts[0]))
elif isinstance(v, list): elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v] return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum): elif inspect.isclass(typ) and issubclass(typ, Enum):

View File

@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
return "".join(result) return "".join(result)
def snake_to_camel(name: str, lower_case: bool = True) -> str: def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str:
"""Convert a snake_case string to camelCase.""" """Convert a snake_case string to camelCase."""
if not name: if not name:
return name return name
@@ -47,7 +47,7 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str:
result = [] result = []
for part in parts: for part in parts:
if part.lower() in abbreviations: if part.lower() in abbreviations and use_abbr:
result.append(part.upper()) result.append(part.upper())
else: else:
if result or not lower_case: if result or not lower_case: