feat(multiplayer): support play
WIP
This commit is contained in:
@@ -2,7 +2,7 @@ from datetime import datetime
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.models.model import UTCBaseModel
|
from app.models.model import UTCBaseModel
|
||||||
from app.models.mods import APIMod, msgpack_to_apimod
|
from app.models.mods import APIMod
|
||||||
from app.models.multiplayer_hub import PlaylistItem
|
from app.models.multiplayer_hub import PlaylistItem
|
||||||
|
|
||||||
from .beatmap import Beatmap, BeatmapResp
|
from .beatmap import Beatmap, BeatmapResp
|
||||||
@@ -79,10 +79,10 @@ class Playlist(PlaylistBase, table=True):
|
|||||||
owner_id=playlist.owner_id,
|
owner_id=playlist.owner_id,
|
||||||
ruleset_id=playlist.ruleset_id,
|
ruleset_id=playlist.ruleset_id,
|
||||||
beatmap_id=playlist.beatmap_id,
|
beatmap_id=playlist.beatmap_id,
|
||||||
required_mods=[msgpack_to_apimod(mod) for mod in playlist.required_mods],
|
required_mods=playlist.required_mods,
|
||||||
allowed_mods=[msgpack_to_apimod(mod) for mod in playlist.allowed_mods],
|
allowed_mods=playlist.allowed_mods,
|
||||||
expired=playlist.expired,
|
expired=playlist.expired,
|
||||||
playlist_order=playlist.order,
|
playlist_order=playlist.playlist_order,
|
||||||
played_at=playlist.played_at,
|
played_at=playlist.played_at,
|
||||||
freestyle=playlist.freestyle,
|
freestyle=playlist.freestyle,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
@@ -99,14 +99,10 @@ class Playlist(PlaylistBase, table=True):
|
|||||||
db_playlist.owner_id = playlist.owner_id
|
db_playlist.owner_id = playlist.owner_id
|
||||||
db_playlist.ruleset_id = playlist.ruleset_id
|
db_playlist.ruleset_id = playlist.ruleset_id
|
||||||
db_playlist.beatmap_id = playlist.beatmap_id
|
db_playlist.beatmap_id = playlist.beatmap_id
|
||||||
db_playlist.required_mods = [
|
db_playlist.required_mods = playlist.required_mods
|
||||||
msgpack_to_apimod(mod) for mod in playlist.required_mods
|
db_playlist.allowed_mods = playlist.allowed_mods
|
||||||
]
|
|
||||||
db_playlist.allowed_mods = [
|
|
||||||
msgpack_to_apimod(mod) for mod in playlist.allowed_mods
|
|
||||||
]
|
|
||||||
db_playlist.expired = playlist.expired
|
db_playlist.expired = playlist.expired
|
||||||
db_playlist.playlist_order = playlist.order
|
db_playlist.playlist_order = playlist.playlist_order
|
||||||
db_playlist.played_at = playlist.played_at
|
db_playlist.played_at = playlist.played_at
|
||||||
db_playlist.freestyle = playlist.freestyle
|
db_playlist.freestyle = playlist.freestyle
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class RoomResp(RoomBase):
|
|||||||
type=room.settings.match_type,
|
type=room.settings.match_type,
|
||||||
queue_mode=room.settings.queue_mode,
|
queue_mode=room.settings.queue_mode,
|
||||||
auto_skip=room.settings.auto_skip,
|
auto_skip=room.settings.auto_skip,
|
||||||
auto_start_duration=room.settings.auto_start_duration,
|
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
|
||||||
status=server_room.status,
|
status=server_room.status,
|
||||||
category=server_room.category,
|
category=server_room.category,
|
||||||
# duration = room.settings.duration,
|
# duration = room.settings.duration,
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
|||||||
|
|
||||||
# optional
|
# optional
|
||||||
# TODO: current_user_attributes
|
# TODO: current_user_attributes
|
||||||
position: int | None = Field(default=None) # multiplayer
|
# position: int | None = Field(default=None) # multiplayer
|
||||||
|
|
||||||
|
|
||||||
class Score(ScoreBase, table=True):
|
class Score(ScoreBase, table=True):
|
||||||
@@ -162,6 +162,7 @@ class ScoreResp(ScoreBase):
|
|||||||
maximum_statistics: ScoreStatistics | None = None
|
maximum_statistics: ScoreStatistics | None = None
|
||||||
rank_global: int | None = None
|
rank_global: int | None = None
|
||||||
rank_country: int | None = None
|
rank_country: int | None = None
|
||||||
|
position: int = 1 # TODO
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||||
@@ -618,6 +619,8 @@ async def process_score(
|
|||||||
fetcher: "Fetcher",
|
fetcher: "Fetcher",
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
|
item_id: int | None = None,
|
||||||
|
room_id: int | None = None,
|
||||||
) -> Score:
|
) -> Score:
|
||||||
assert user.id
|
assert user.id
|
||||||
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||||
@@ -649,6 +652,8 @@ async def process_score(
|
|||||||
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
|
nsmall_tick_hit=info.statistics.get(HitResult.SMALL_TICK_HIT, 0),
|
||||||
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
|
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
|
||||||
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
|
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
|
||||||
|
playlist_item_id=item_id,
|
||||||
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
if can_get_pp:
|
if can_get_pp:
|
||||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ from typing import Literal, NotRequired, TypedDict
|
|||||||
|
|
||||||
from app.path import STATIC_DIR
|
from app.path import STATIC_DIR
|
||||||
|
|
||||||
from msgpack_lazer_api import APIMod as MsgpackAPIMod
|
|
||||||
|
|
||||||
|
|
||||||
class APIMod(TypedDict):
|
class APIMod(TypedDict):
|
||||||
acronym: str
|
acronym: str
|
||||||
@@ -169,13 +167,3 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
|
|||||||
if expected_value != NO_CHECK and value != expected_value:
|
if expected_value != NO_CHECK and value != expected_value:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def msgpack_to_apimod(mod: MsgpackAPIMod) -> APIMod:
|
|
||||||
"""
|
|
||||||
Convert a MsgpackAPIMod to an APIMod.
|
|
||||||
"""
|
|
||||||
return APIMod(
|
|
||||||
acronym=mod.acronym,
|
|
||||||
settings=mod.settings,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
import asyncio
|
||||||
import datetime
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
|
||||||
|
|
||||||
from app.database.beatmap import Beatmap
|
from app.database.beatmap import Beatmap
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import engine
|
||||||
from app.exception import InvokeException
|
from app.exception import InvokeException
|
||||||
|
|
||||||
|
from .mods import APIMod
|
||||||
from .room import (
|
from .room import (
|
||||||
DownloadState,
|
DownloadState,
|
||||||
MatchType,
|
MatchType,
|
||||||
@@ -18,15 +21,14 @@ from .room import (
|
|||||||
RoomStatus,
|
RoomStatus,
|
||||||
)
|
)
|
||||||
from .signalr import (
|
from .signalr import (
|
||||||
EnumByIndex,
|
SignalRMeta,
|
||||||
MessagePackArrayModel,
|
SignalRUnionMessage,
|
||||||
UserState,
|
UserState,
|
||||||
msgpack_union,
|
|
||||||
msgpack_union_dump,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from msgpack_lazer_api import APIMod
|
from pydantic import BaseModel, Field
|
||||||
from pydantic import Field, field_serializer, field_validator
|
from sqlalchemy import update
|
||||||
|
from sqlmodel import col
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -40,37 +42,37 @@ class MultiplayerClientState(UserState):
|
|||||||
room_id: int = 0
|
room_id: int = 0
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerRoomSettings(MessagePackArrayModel):
|
class MultiplayerRoomSettings(BaseModel):
|
||||||
name: str = "Unnamed Room"
|
name: str = "Unnamed Room"
|
||||||
playlist_item_id: int = 0
|
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||||
password: str = ""
|
password: str = ""
|
||||||
match_type: Annotated[MatchType, EnumByIndex(MatchType)] = MatchType.HEAD_TO_HEAD
|
match_type: MatchType = MatchType.HEAD_TO_HEAD
|
||||||
queue_mode: Annotated[QueueMode, EnumByIndex(QueueMode)] = QueueMode.HOST_ONLY
|
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||||
auto_start_duration: int = 0
|
auto_start_duration: timedelta = timedelta(seconds=0)
|
||||||
auto_skip: bool = False
|
auto_skip: bool = False
|
||||||
|
|
||||||
|
|
||||||
class BeatmapAvailability(MessagePackArrayModel):
|
class BeatmapAvailability(BaseModel):
|
||||||
state: Annotated[DownloadState, EnumByIndex(DownloadState)] = DownloadState.UNKNOWN
|
state: DownloadState = DownloadState.UNKNOWN
|
||||||
progress: float | None = None
|
progress: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class _MatchUserState(MessagePackArrayModel): ...
|
class _MatchUserState(SignalRUnionMessage): ...
|
||||||
|
|
||||||
|
|
||||||
class TeamVersusUserState(_MatchUserState):
|
class TeamVersusUserState(_MatchUserState):
|
||||||
team_id: int
|
team_id: int
|
||||||
|
|
||||||
type: Literal[0] = Field(0, exclude=True)
|
union_type: ClassVar[Literal[0]] = 0
|
||||||
|
|
||||||
|
|
||||||
MatchUserState = TeamVersusUserState
|
MatchUserState = TeamVersusUserState
|
||||||
|
|
||||||
|
|
||||||
class _MatchRoomState(MessagePackArrayModel): ...
|
class _MatchRoomState(SignalRUnionMessage): ...
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerTeam(MessagePackArrayModel):
|
class MultiplayerTeam(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@@ -83,24 +85,24 @@ class TeamVersusRoomState(_MatchRoomState):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
type: Literal[0] = Field(0, exclude=True)
|
union_type: ClassVar[Literal[0]] = 0
|
||||||
|
|
||||||
|
|
||||||
MatchRoomState = TeamVersusRoomState
|
MatchRoomState = TeamVersusRoomState
|
||||||
|
|
||||||
|
|
||||||
class PlaylistItem(MessagePackArrayModel):
|
class PlaylistItem(BaseModel):
|
||||||
id: int
|
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||||
owner_id: int
|
owner_id: int
|
||||||
beatmap_id: int
|
beatmap_id: int
|
||||||
checksum: str
|
beatmap_checksum: str
|
||||||
ruleset_id: int
|
ruleset_id: int
|
||||||
required_mods: list[APIMod] = Field(default_factory=list)
|
required_mods: list[APIMod] = Field(default_factory=list)
|
||||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||||
expired: bool
|
expired: bool
|
||||||
order: int
|
playlist_order: int
|
||||||
played_at: datetime.datetime | None = None
|
played_at: datetime | None = None
|
||||||
star: float
|
star_rating: float
|
||||||
freestyle: bool
|
freestyle: bool
|
||||||
|
|
||||||
def validate_user_mods(
|
def validate_user_mods(
|
||||||
@@ -127,7 +129,10 @@ class PlaylistItem(MessagePackArrayModel):
|
|||||||
|
|
||||||
# Check if mods are valid for the ruleset
|
# Check if mods are valid for the ruleset
|
||||||
for mod in proposed_mods:
|
for mod in proposed_mods:
|
||||||
if ruleset_key not in API_MODS or mod.acronym not in API_MODS[ruleset_key]:
|
if (
|
||||||
|
ruleset_key not in API_MODS
|
||||||
|
or mod["acronym"] not in API_MODS[ruleset_key]
|
||||||
|
):
|
||||||
all_proposed_valid = False
|
all_proposed_valid = False
|
||||||
continue
|
continue
|
||||||
valid_mods.append(mod)
|
valid_mods.append(mod)
|
||||||
@@ -136,35 +141,35 @@ class PlaylistItem(MessagePackArrayModel):
|
|||||||
incompatible_mods = set()
|
incompatible_mods = set()
|
||||||
final_valid_mods = []
|
final_valid_mods = []
|
||||||
for mod in valid_mods:
|
for mod in valid_mods:
|
||||||
if mod.acronym in incompatible_mods:
|
if mod["acronym"] in incompatible_mods:
|
||||||
all_proposed_valid = False
|
all_proposed_valid = False
|
||||||
continue
|
continue
|
||||||
setting_mods = API_MODS[ruleset_key].get(mod.acronym)
|
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
|
||||||
if setting_mods:
|
if setting_mods:
|
||||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||||
final_valid_mods.append(mod)
|
final_valid_mods.append(mod)
|
||||||
|
|
||||||
# If not freestyle, check against allowed mods
|
# If not freestyle, check against allowed mods
|
||||||
if not self.freestyle:
|
if not self.freestyle:
|
||||||
allowed_acronyms = {mod.acronym for mod in self.allowed_mods}
|
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||||
filtered_valid_mods = []
|
filtered_valid_mods = []
|
||||||
for mod in final_valid_mods:
|
for mod in final_valid_mods:
|
||||||
if mod.acronym not in allowed_acronyms:
|
if mod["acronym"] not in allowed_acronyms:
|
||||||
all_proposed_valid = False
|
all_proposed_valid = False
|
||||||
else:
|
else:
|
||||||
filtered_valid_mods.append(mod)
|
filtered_valid_mods.append(mod)
|
||||||
final_valid_mods = filtered_valid_mods
|
final_valid_mods = filtered_valid_mods
|
||||||
|
|
||||||
# Check compatibility with required mods
|
# Check compatibility with required mods
|
||||||
required_mod_acronyms = {mod.acronym for mod in self.required_mods}
|
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||||
all_mod_acronyms = {
|
all_mod_acronyms = {
|
||||||
mod.acronym for mod in final_valid_mods
|
mod["acronym"] for mod in final_valid_mods
|
||||||
} | required_mod_acronyms
|
} | required_mod_acronyms
|
||||||
|
|
||||||
# Check for incompatibility between required and user mods
|
# Check for incompatibility between required and user mods
|
||||||
filtered_valid_mods = []
|
filtered_valid_mods = []
|
||||||
for mod in final_valid_mods:
|
for mod in final_valid_mods:
|
||||||
mod_acronym = mod.acronym
|
mod_acronym = mod["acronym"]
|
||||||
is_compatible = True
|
is_compatible = True
|
||||||
|
|
||||||
for other_acronym in all_mod_acronyms:
|
for other_acronym in all_mod_acronyms:
|
||||||
@@ -181,23 +186,29 @@ class PlaylistItem(MessagePackArrayModel):
|
|||||||
|
|
||||||
return all_proposed_valid, filtered_valid_mods
|
return all_proposed_valid, filtered_valid_mods
|
||||||
|
|
||||||
|
def clone(self) -> "PlaylistItem":
|
||||||
|
copy = self.model_copy()
|
||||||
|
copy.required_mods = list(self.required_mods)
|
||||||
|
copy.allowed_mods = list(self.allowed_mods)
|
||||||
|
return copy
|
||||||
|
|
||||||
class _MultiplayerCountdown(MessagePackArrayModel):
|
|
||||||
id: int
|
class _MultiplayerCountdown(BaseModel):
|
||||||
remaining: int
|
id: int = 0
|
||||||
is_exclusive: bool
|
remaining: timedelta
|
||||||
|
is_exclusive: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MatchStartCountdown(_MultiplayerCountdown):
|
class MatchStartCountdown(_MultiplayerCountdown):
|
||||||
type: Literal[0] = Field(0, exclude=True)
|
union_type: ClassVar[Literal[0]] = 0
|
||||||
|
|
||||||
|
|
||||||
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
||||||
type: Literal[1] = Field(1, exclude=True)
|
union_type: ClassVar[Literal[1]] = 1
|
||||||
|
|
||||||
|
|
||||||
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||||
type: Literal[2] = Field(2, exclude=True)
|
union_type: ClassVar[Literal[2]] = 2
|
||||||
|
|
||||||
|
|
||||||
MultiplayerCountdown = (
|
MultiplayerCountdown = (
|
||||||
@@ -205,11 +216,9 @@ MultiplayerCountdown = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerRoomUser(MessagePackArrayModel):
|
class MultiplayerRoomUser(BaseModel):
|
||||||
user_id: int
|
user_id: int
|
||||||
state: Annotated[MultiplayerUserState, EnumByIndex(MultiplayerUserState)] = (
|
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||||
MultiplayerUserState.IDLE
|
|
||||||
)
|
|
||||||
availability: BeatmapAvailability = BeatmapAvailability(
|
availability: BeatmapAvailability = BeatmapAvailability(
|
||||||
state=DownloadState.UNKNOWN, progress=None
|
state=DownloadState.UNKNOWN, progress=None
|
||||||
)
|
)
|
||||||
@@ -218,50 +227,33 @@ class MultiplayerRoomUser(MessagePackArrayModel):
|
|||||||
ruleset_id: int | None = None # freestyle
|
ruleset_id: int | None = None # freestyle
|
||||||
beatmap_id: int | None = None # freestyle
|
beatmap_id: int | None = None # freestyle
|
||||||
|
|
||||||
@field_validator("match_state", mode="before")
|
|
||||||
def union_validate(v: Any):
|
|
||||||
if isinstance(v, list):
|
|
||||||
return msgpack_union(v)
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_serializer("match_state")
|
class MultiplayerRoom(BaseModel):
|
||||||
def union_serialize(v: Any):
|
|
||||||
return msgpack_union_dump(v)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerRoom(MessagePackArrayModel):
|
|
||||||
room_id: int
|
room_id: int
|
||||||
state: Annotated[MultiplayerRoomState, EnumByIndex(MultiplayerRoomState)]
|
state: MultiplayerRoomState
|
||||||
settings: MultiplayerRoomSettings
|
settings: MultiplayerRoomSettings
|
||||||
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
||||||
host: MultiplayerRoomUser | None = None
|
host: MultiplayerRoomUser | None = None
|
||||||
match_state: MatchRoomState | None = None
|
match_state: MatchRoomState | None = None
|
||||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||||
active_cooldowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
||||||
channel_id: int
|
channel_id: int
|
||||||
|
|
||||||
@field_validator("match_state", mode="before")
|
|
||||||
def union_validate(v: Any):
|
|
||||||
if isinstance(v, list):
|
|
||||||
return msgpack_union(v)
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_serializer("match_state")
|
|
||||||
def union_serialize(v: Any):
|
|
||||||
return msgpack_union_dump(v)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerQueue:
|
class MultiplayerQueue:
|
||||||
def __init__(self, room: "ServerMultiplayerRoom", hub: "MultiplayerHub"):
|
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||||
self.server_room = room
|
self.server_room = room
|
||||||
self.hub = hub
|
|
||||||
self.current_index = 0
|
self.current_index = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hub(self) -> "MultiplayerHub":
|
||||||
|
return self.server_room.hub
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def upcoming_items(self):
|
def upcoming_items(self):
|
||||||
return sorted(
|
return sorted(
|
||||||
(item for item in self.room.playlist if not item.expired),
|
(item for item in self.room.playlist if not item.expired),
|
||||||
key=lambda i: i.order,
|
key=lambda i: i.playlist_order,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -323,9 +315,9 @@ class MultiplayerQueue:
|
|||||||
)
|
)
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
for idx, item in enumerate(ordered_active_items):
|
for idx, item in enumerate(ordered_active_items):
|
||||||
if item.order == idx:
|
if item.playlist_order == idx:
|
||||||
continue
|
continue
|
||||||
item.order = idx
|
item.playlist_order = idx
|
||||||
await Playlist.update(item, self.room.room_id, session)
|
await Playlist.update(item, self.room.room_id, session)
|
||||||
await self.hub.playlist_changed(
|
await self.hub.playlist_changed(
|
||||||
self.server_room, item, beatmap_changed=False
|
self.server_room, item, beatmap_changed=False
|
||||||
@@ -338,7 +330,7 @@ class MultiplayerQueue:
|
|||||||
if upcoming_items
|
if upcoming_items
|
||||||
else max(
|
else max(
|
||||||
self.room.playlist,
|
self.room.playlist,
|
||||||
key=lambda i: i.played_at or datetime.datetime.min,
|
key=lambda i: i.played_at or datetime.min,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.current_index = self.room.playlist.index(next_item)
|
self.current_index = self.room.playlist.index(next_item)
|
||||||
@@ -356,14 +348,7 @@ class MultiplayerQueue:
|
|||||||
|
|
||||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||||
if (
|
if (
|
||||||
len(
|
len([True for u in self.room.playlist if u.owner_id == user.user_id])
|
||||||
list(
|
|
||||||
filter(
|
|
||||||
lambda x: x.owner_id == user.user_id,
|
|
||||||
self.room.playlist,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
>= limit
|
>= limit
|
||||||
):
|
):
|
||||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||||
@@ -376,11 +361,11 @@ class MultiplayerQueue:
|
|||||||
beatmap = await session.get(Beatmap, item.beatmap_id)
|
beatmap = await session.get(Beatmap, item.beatmap_id)
|
||||||
if beatmap is None:
|
if beatmap is None:
|
||||||
raise InvokeException("Beatmap not found")
|
raise InvokeException("Beatmap not found")
|
||||||
if item.checksum != beatmap.checksum:
|
if item.beatmap_checksum != beatmap.checksum:
|
||||||
raise InvokeException("Checksum mismatch")
|
raise InvokeException("Checksum mismatch")
|
||||||
# TODO: mods validation
|
# TODO: mods validation
|
||||||
item.owner_id = user.user_id
|
item.owner_id = user.user_id
|
||||||
item.star = float(
|
item.star_rating = float(
|
||||||
beatmap.difficulty_rating
|
beatmap.difficulty_rating
|
||||||
) # FIXME: beatmap use decimal
|
) # FIXME: beatmap use decimal
|
||||||
await Playlist.add_to_db(item, self.room.room_id, session)
|
await Playlist.add_to_db(item, self.room.room_id, session)
|
||||||
@@ -400,7 +385,7 @@ class MultiplayerQueue:
|
|||||||
beatmap = await session.get(Beatmap, item.beatmap_id)
|
beatmap = await session.get(Beatmap, item.beatmap_id)
|
||||||
if beatmap is None:
|
if beatmap is None:
|
||||||
raise InvokeException("Beatmap not found")
|
raise InvokeException("Beatmap not found")
|
||||||
if item.checksum != beatmap.checksum:
|
if item.beatmap_checksum != beatmap.checksum:
|
||||||
raise InvokeException("Checksum mismatch")
|
raise InvokeException("Checksum mismatch")
|
||||||
|
|
||||||
existing_item = next(
|
existing_item = next(
|
||||||
@@ -423,8 +408,8 @@ class MultiplayerQueue:
|
|||||||
|
|
||||||
# TODO: mods validation
|
# TODO: mods validation
|
||||||
item.owner_id = user.user_id
|
item.owner_id = user.user_id
|
||||||
item.star = float(beatmap.difficulty_rating)
|
item.star_rating = float(beatmap.difficulty_rating)
|
||||||
item.order = existing_item.order
|
item.playlist_order = existing_item.playlist_order
|
||||||
|
|
||||||
await Playlist.update(item, self.room.room_id, session)
|
await Playlist.update(item, self.room.room_id, session)
|
||||||
|
|
||||||
@@ -437,7 +422,8 @@ class MultiplayerQueue:
|
|||||||
await self.hub.playlist_changed(
|
await self.hub.playlist_changed(
|
||||||
self.server_room,
|
self.server_room,
|
||||||
item,
|
item,
|
||||||
beatmap_changed=item.checksum != existing_item.checksum,
|
beatmap_changed=item.beatmap_checksum
|
||||||
|
!= existing_item.beatmap_checksum,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||||
@@ -477,12 +463,46 @@ class MultiplayerQueue:
|
|||||||
await self.update_current_item()
|
await self.update_current_item()
|
||||||
await self.hub.playlist_removed(self.server_room, item.id)
|
await self.hub.playlist_removed(self.server_room, item.id)
|
||||||
|
|
||||||
|
async def finish_current_item(self):
|
||||||
|
from app.database import Playlist
|
||||||
|
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
played_at = datetime.now(UTC)
|
||||||
|
await session.execute(
|
||||||
|
update(Playlist)
|
||||||
|
.where(
|
||||||
|
col(Playlist.id) == self.current_item.id,
|
||||||
|
col(Playlist.room_id) == self.room.room_id,
|
||||||
|
)
|
||||||
|
.values(expired=True, played_at=played_at)
|
||||||
|
)
|
||||||
|
self.room.playlist[self.current_index].expired = True
|
||||||
|
self.room.playlist[self.current_index].played_at = played_at
|
||||||
|
await self.hub.playlist_changed(self.server_room, self.current_item, True)
|
||||||
|
await self.update_order()
|
||||||
|
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||||
|
playitem.expired for playitem in self.room.playlist
|
||||||
|
):
|
||||||
|
assert self.room.host
|
||||||
|
await self.add_item(self.current_item.clone(), self.room.host)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_item(self):
|
def current_item(self):
|
||||||
"""Get the current playlist item"""
|
return self.room.playlist[self.current_index]
|
||||||
current_id = self.room.settings.playlist_item_id
|
|
||||||
return next(
|
|
||||||
(item for item in self.room.playlist if item.id == current_id),
|
@dataclass
|
||||||
|
class CountdownInfo:
|
||||||
|
countdown: MultiplayerCountdown
|
||||||
|
duration: timedelta
|
||||||
|
task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
def __init__(self, countdown: MultiplayerCountdown):
|
||||||
|
self.countdown = countdown
|
||||||
|
self.duration = (
|
||||||
|
countdown.remaining
|
||||||
|
if countdown.remaining > timedelta(seconds=0)
|
||||||
|
else timedelta(seconds=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -491,5 +511,79 @@ class ServerMultiplayerRoom:
|
|||||||
room: MultiplayerRoom
|
room: MultiplayerRoom
|
||||||
category: RoomCategory
|
category: RoomCategory
|
||||||
status: RoomStatus
|
status: RoomStatus
|
||||||
start_at: datetime.datetime
|
start_at: datetime
|
||||||
|
hub: "MultiplayerHub"
|
||||||
queue: MultiplayerQueue | None = None
|
queue: MultiplayerQueue | None = None
|
||||||
|
_next_countdown_id: int = 0
|
||||||
|
_countdown_id_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||||
|
_tracked_countdown: dict[int, CountdownInfo] = field(default_factory=dict)
|
||||||
|
|
||||||
|
async def get_next_countdown_id(self) -> int:
|
||||||
|
async with self._countdown_id_lock:
|
||||||
|
self._next_countdown_id += 1
|
||||||
|
return self._next_countdown_id
|
||||||
|
|
||||||
|
async def start_countdown(
|
||||||
|
self,
|
||||||
|
countdown: MultiplayerCountdown,
|
||||||
|
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
|
||||||
|
):
|
||||||
|
async def _countdown_task(self: "ServerMultiplayerRoom"):
|
||||||
|
await asyncio.sleep(info.duration.total_seconds())
|
||||||
|
await self.stop_countdown(countdown)
|
||||||
|
if on_complete is not None:
|
||||||
|
await on_complete(self)
|
||||||
|
|
||||||
|
if countdown.is_exclusive:
|
||||||
|
await self.stop_all_countdowns()
|
||||||
|
|
||||||
|
countdown.id = await self.get_next_countdown_id()
|
||||||
|
info = CountdownInfo(countdown)
|
||||||
|
self.room.active_countdowns.append(info.countdown)
|
||||||
|
self._tracked_countdown[countdown.id] = info
|
||||||
|
await self.hub.send_match_event(
|
||||||
|
self, CountdownStartedEvent(countdown=info.countdown)
|
||||||
|
)
|
||||||
|
info.task = asyncio.create_task(_countdown_task(self))
|
||||||
|
|
||||||
|
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||||
|
info = next(
|
||||||
|
(
|
||||||
|
info
|
||||||
|
for info in self._tracked_countdown.values()
|
||||||
|
if info.countdown.id == countdown.id
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if info is None:
|
||||||
|
return
|
||||||
|
if info.task is not None and not info.task.done():
|
||||||
|
info.task.cancel()
|
||||||
|
del self._tracked_countdown[countdown.id]
|
||||||
|
self.room.active_countdowns.remove(countdown)
|
||||||
|
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
|
||||||
|
|
||||||
|
async def stop_all_countdowns(self):
|
||||||
|
for countdown in list(self._tracked_countdown.values()):
|
||||||
|
await self.stop_countdown(countdown.countdown)
|
||||||
|
|
||||||
|
self._tracked_countdown.clear()
|
||||||
|
self.room.active_countdowns.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class _MatchServerEvent(BaseModel): ...
|
||||||
|
|
||||||
|
|
||||||
|
class CountdownStartedEvent(_MatchServerEvent):
|
||||||
|
countdown: MultiplayerCountdown
|
||||||
|
|
||||||
|
type: Literal[0] = Field(default=0, exclude=True)
|
||||||
|
|
||||||
|
|
||||||
|
class CountdownStoppedEvent(_MatchServerEvent):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
type: Literal[1] = Field(default=1, exclude=True)
|
||||||
|
|
||||||
|
|
||||||
|
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
|
||||||
|
|||||||
@@ -53,6 +53,15 @@ class MultiplayerUserState(str, Enum):
|
|||||||
RESULTS = "results"
|
RESULTS = "results"
|
||||||
SPECTATING = "spectating"
|
SPECTATING = "spectating"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_playing(self) -> bool:
|
||||||
|
return self in {
|
||||||
|
self.WAITING_FOR_LOAD,
|
||||||
|
self.PLAYING,
|
||||||
|
self.READY_FOR_GAMEPLAY,
|
||||||
|
self.LOADED,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class DownloadState(str, Enum):
|
class DownloadState(str, Enum):
|
||||||
UNKNOWN = "unknown"
|
UNKNOWN = "unknown"
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class SignalRMeta:
|
|||||||
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||||
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||||
use_upper_case: bool = False # use upper CamelCase for field names
|
use_upper_case: bool = False # use upper CamelCase for field names
|
||||||
|
use_abbr: bool = True
|
||||||
|
|
||||||
|
|
||||||
class SignalRUnionMessage(BaseModel):
|
class SignalRUnionMessage(BaseModel):
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
|
from app.database import (
|
||||||
|
Beatmap,
|
||||||
|
Playlist,
|
||||||
|
Score,
|
||||||
|
ScoreResp,
|
||||||
|
ScoreToken,
|
||||||
|
ScoreTokenResp,
|
||||||
|
User,
|
||||||
|
)
|
||||||
from app.database.score import get_leaderboard, process_score, process_user
|
from app.database.score import get_leaderboard, process_score, process_user
|
||||||
from app.dependencies.database import get_db, get_redis
|
from app.dependencies.database import get_db, get_redis
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
|
from app.fetcher import Fetcher
|
||||||
from app.models.beatmap import BeatmapRankStatus
|
from app.models.beatmap import BeatmapRankStatus
|
||||||
from app.models.score import (
|
from app.models.score import (
|
||||||
INT_TO_MODE,
|
INT_TO_MODE,
|
||||||
@@ -13,6 +22,7 @@ from app.models.score import (
|
|||||||
Rank,
|
Rank,
|
||||||
SoloScoreSubmissionInfo,
|
SoloScoreSubmissionInfo,
|
||||||
)
|
)
|
||||||
|
from app.signalr.hub import MultiplayerHubs
|
||||||
|
|
||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
@@ -24,6 +34,68 @@ from sqlmodel import col, select
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
async def submit_score(
|
||||||
|
info: SoloScoreSubmissionInfo,
|
||||||
|
beatmap: int,
|
||||||
|
token: int,
|
||||||
|
current_user: User,
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
fetcher: Fetcher,
|
||||||
|
item_id: int | None = None,
|
||||||
|
room_id: int | None = None,
|
||||||
|
):
|
||||||
|
if not info.passed:
|
||||||
|
info.rank = Rank.F
|
||||||
|
score_token = (
|
||||||
|
await db.exec(
|
||||||
|
select(ScoreToken)
|
||||||
|
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||||
|
.where(ScoreToken.id == token)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if not score_token or score_token.user_id != current_user.id:
|
||||||
|
raise HTTPException(status_code=404, detail="Score token not found")
|
||||||
|
if score_token.score_id:
|
||||||
|
score = (
|
||||||
|
await db.exec(
|
||||||
|
select(Score).where(
|
||||||
|
Score.id == score_token.score_id,
|
||||||
|
Score.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if not score:
|
||||||
|
raise HTTPException(status_code=404, detail="Score not found")
|
||||||
|
else:
|
||||||
|
beatmap_status = (
|
||||||
|
await db.exec(select(Beatmap.beatmap_status).where(Beatmap.id == beatmap))
|
||||||
|
).first()
|
||||||
|
if beatmap_status is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
|
ranked = beatmap_status in {
|
||||||
|
BeatmapRankStatus.RANKED,
|
||||||
|
BeatmapRankStatus.APPROVED,
|
||||||
|
}
|
||||||
|
score = await process_score(
|
||||||
|
current_user,
|
||||||
|
beatmap,
|
||||||
|
ranked,
|
||||||
|
score_token,
|
||||||
|
info,
|
||||||
|
fetcher,
|
||||||
|
db,
|
||||||
|
redis,
|
||||||
|
)
|
||||||
|
await db.refresh(current_user)
|
||||||
|
score_id = score.id
|
||||||
|
score_token.score_id = score_id
|
||||||
|
await process_user(db, current_user, score, ranked)
|
||||||
|
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||||
|
assert score is not None
|
||||||
|
return await ScoreResp.from_db(db, score)
|
||||||
|
|
||||||
|
|
||||||
class BeatmapScores(BaseModel):
|
class BeatmapScores(BaseModel):
|
||||||
scores: list[ScoreResp]
|
scores: list[ScoreResp]
|
||||||
userScore: ScoreResp | None = None
|
userScore: ScoreResp | None = None
|
||||||
@@ -97,9 +169,10 @@ async def get_user_beatmap_score(
|
|||||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
resp = await ScoreResp.from_db(db, user_score)
|
||||||
return BeatmapUserScore(
|
return BeatmapUserScore(
|
||||||
position=user_score.position if user_score.position is not None else 0,
|
position=resp.rank_global or 0,
|
||||||
score=await ScoreResp.from_db(db, user_score),
|
score=resp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -173,55 +246,95 @@ async def submit_solo_score(
|
|||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
fetcher=Depends(get_fetcher),
|
fetcher=Depends(get_fetcher),
|
||||||
):
|
):
|
||||||
if not info.passed:
|
return await submit_score(info, beatmap, token, current_user, db, redis, fetcher)
|
||||||
info.rank = Rank.F
|
|
||||||
async with db:
|
|
||||||
score_token = (
|
@router.post(
|
||||||
await db.exec(
|
"/rooms/{room_id}/playlist/{playlist_id}/scores", response_model=ScoreTokenResp
|
||||||
select(ScoreToken)
|
)
|
||||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
async def create_playlist_score(
|
||||||
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
|
room_id: int,
|
||||||
|
playlist_id: int,
|
||||||
|
beatmap_id: int = Form(),
|
||||||
|
beatmap_hash: str = Form(),
|
||||||
|
ruleset_id: int = Form(..., ge=0, le=3),
|
||||||
|
version_hash: str = Form(""),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
room = MultiplayerHubs.rooms[room_id]
|
||||||
|
if not room:
|
||||||
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
item = (
|
||||||
|
await session.exec(
|
||||||
|
select(Playlist).where(
|
||||||
|
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||||
)
|
)
|
||||||
).first()
|
)
|
||||||
if not score_token or score_token.user_id != current_user.id:
|
).first()
|
||||||
raise HTTPException(status_code=404, detail="Score token not found")
|
if not item:
|
||||||
if score_token.score_id:
|
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||||
score = (
|
|
||||||
await db.exec(
|
# validate
|
||||||
select(Score).where(
|
if not item.freestyle:
|
||||||
Score.id == score_token.score_id,
|
if item.ruleset_id != ruleset_id:
|
||||||
Score.user_id == current_user.id,
|
raise HTTPException(
|
||||||
)
|
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||||
)
|
|
||||||
).first()
|
|
||||||
if not score:
|
|
||||||
raise HTTPException(status_code=404, detail="Score not found")
|
|
||||||
else:
|
|
||||||
beatmap_status = (
|
|
||||||
await db.exec(
|
|
||||||
select(Beatmap.beatmap_status).where(Beatmap.id == beatmap)
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
if beatmap_status is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
|
||||||
ranked = beatmap_status in {
|
|
||||||
BeatmapRankStatus.RANKED,
|
|
||||||
BeatmapRankStatus.APPROVED,
|
|
||||||
}
|
|
||||||
score = await process_score(
|
|
||||||
current_user,
|
|
||||||
beatmap,
|
|
||||||
ranked,
|
|
||||||
score_token,
|
|
||||||
info,
|
|
||||||
fetcher,
|
|
||||||
db,
|
|
||||||
redis,
|
|
||||||
)
|
)
|
||||||
await db.refresh(current_user)
|
if item.beatmap_id != beatmap_id:
|
||||||
score_id = score.id
|
raise HTTPException(
|
||||||
score_token.score_id = score_id
|
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||||
await process_user(db, current_user, score, ranked)
|
)
|
||||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
# TODO: max attempts
|
||||||
assert score is not None
|
if item.expired:
|
||||||
return await ScoreResp.from_db(db, score)
|
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||||
|
if item.played_at:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Playlist item has already been played"
|
||||||
|
)
|
||||||
|
# 这里应该不用验证mod了吧。。。
|
||||||
|
|
||||||
|
score_token = ScoreToken(
|
||||||
|
user_id=current_user.id,
|
||||||
|
beatmap_id=beatmap_id,
|
||||||
|
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||||
|
playlist_item_id=playlist_id,
|
||||||
|
)
|
||||||
|
session.add(score_token)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(score_token)
|
||||||
|
return ScoreTokenResp.from_db(score_token)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/rooms/{room_id}/playlist/{playlist_id}/scores/{token}")
|
||||||
|
async def submit_playlist_score(
|
||||||
|
room_id: int,
|
||||||
|
playlist_id: int,
|
||||||
|
token: int,
|
||||||
|
info: SoloScoreSubmissionInfo,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
|
fetcher: Fetcher = Depends(get_fetcher),
|
||||||
|
):
|
||||||
|
item = (
|
||||||
|
await session.exec(
|
||||||
|
select(Playlist).where(
|
||||||
|
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if not item:
|
||||||
|
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||||
|
score_resp = await submit_score(
|
||||||
|
info,
|
||||||
|
item.beatmap_id,
|
||||||
|
token,
|
||||||
|
current_user,
|
||||||
|
session,
|
||||||
|
redis,
|
||||||
|
fetcher,
|
||||||
|
item.id,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
return score_resp
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
from app.database import Room
|
from app.database import Room
|
||||||
@@ -8,8 +10,11 @@ from app.database.playlists import Playlist
|
|||||||
from app.dependencies.database import engine
|
from app.dependencies.database import engine
|
||||||
from app.exception import InvokeException
|
from app.exception import InvokeException
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
|
from app.models.mods import APIMod
|
||||||
from app.models.multiplayer_hub import (
|
from app.models.multiplayer_hub import (
|
||||||
BeatmapAvailability,
|
BeatmapAvailability,
|
||||||
|
ForceGameplayStartCountdown,
|
||||||
|
MatchServerEvent,
|
||||||
MultiplayerClientState,
|
MultiplayerClientState,
|
||||||
MultiplayerQueue,
|
MultiplayerQueue,
|
||||||
MultiplayerRoom,
|
MultiplayerRoom,
|
||||||
@@ -17,16 +22,22 @@ from app.models.multiplayer_hub import (
|
|||||||
PlaylistItem,
|
PlaylistItem,
|
||||||
ServerMultiplayerRoom,
|
ServerMultiplayerRoom,
|
||||||
)
|
)
|
||||||
from app.models.room import RoomCategory, RoomStatus
|
from app.models.room import (
|
||||||
|
DownloadState,
|
||||||
|
MultiplayerRoomState,
|
||||||
|
MultiplayerUserState,
|
||||||
|
RoomCategory,
|
||||||
|
RoomStatus,
|
||||||
|
)
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.models.signalr import serialize_to_list
|
|
||||||
|
|
||||||
from .hub import Client, Hub
|
from .hub import Client, Hub
|
||||||
|
|
||||||
from msgpack_lazer_api import APIMod
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
GAMEPLAY_LOAD_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerHub(Hub[MultiplayerClientState]):
|
class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||||
@override
|
@override
|
||||||
@@ -58,7 +69,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
type=room.settings.match_type,
|
type=room.settings.match_type,
|
||||||
queue_mode=room.settings.queue_mode,
|
queue_mode=room.settings.queue_mode,
|
||||||
auto_skip=room.settings.auto_skip,
|
auto_skip=room.settings.auto_skip,
|
||||||
auto_start_duration=room.settings.auto_start_duration,
|
auto_start_duration=int(
|
||||||
|
room.settings.auto_start_duration.total_seconds()
|
||||||
|
),
|
||||||
host_id=client.user_id,
|
host_id=client.user_id,
|
||||||
status=RoomStatus.IDLE,
|
status=RoomStatus.IDLE,
|
||||||
)
|
)
|
||||||
@@ -75,10 +88,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
category=RoomCategory.NORMAL,
|
category=RoomCategory.NORMAL,
|
||||||
status=RoomStatus.IDLE,
|
status=RoomStatus.IDLE,
|
||||||
start_at=starts_at,
|
start_at=starts_at,
|
||||||
|
hub=self,
|
||||||
)
|
)
|
||||||
queue = MultiplayerQueue(
|
queue = MultiplayerQueue(
|
||||||
room=server_room,
|
room=server_room,
|
||||||
hub=self,
|
|
||||||
)
|
)
|
||||||
server_room.queue = queue
|
server_room.queue = queue
|
||||||
self.rooms[room.room_id] = server_room
|
self.rooms[room.room_id] = server_room
|
||||||
@@ -86,6 +99,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
client, room.room_id, room.settings.password
|
client, room.room_id, room.settings.password
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def JoinRoom(self, client: Client, room_id: int):
|
||||||
|
return self.JoinRoomWithPassword(client, room_id, "")
|
||||||
|
|
||||||
async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str):
|
async def JoinRoomWithPassword(self, client: Client, room_id: int, password: str):
|
||||||
logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}")
|
logger.info(f"[MultiplayerHub] {client.user_id} joining room {room_id}")
|
||||||
store = self.get_or_create_state(client)
|
store = self.get_or_create_state(client)
|
||||||
@@ -105,12 +121,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
# from CreateRoom
|
# from CreateRoom
|
||||||
room.host = user
|
room.host = user
|
||||||
store.room_id = room_id
|
store.room_id = room_id
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(self.group_id(room_id), "UserJoined", user)
|
||||||
self.group_id(room_id), "UserJoined", serialize_to_list(user)
|
|
||||||
)
|
|
||||||
room.users.append(user)
|
room.users.append(user)
|
||||||
self.add_to_group(client, self.group_id(room_id))
|
self.add_to_group(client, self.group_id(room_id))
|
||||||
return serialize_to_list(room)
|
return room
|
||||||
|
|
||||||
async def ChangeBeatmapAvailability(
|
async def ChangeBeatmapAvailability(
|
||||||
self, client: Client, beatmap_availability: BeatmapAvailability
|
self, client: Client, beatmap_availability: BeatmapAvailability
|
||||||
@@ -132,12 +146,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
and availability.progress == beatmap_availability.progress
|
and availability.progress == beatmap_availability.progress
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
user.availability = availability
|
user.availability = beatmap_availability
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(store.room_id),
|
self.group_id(store.room_id),
|
||||||
"UserBeatmapAvailabilityChanged",
|
"UserBeatmapAvailabilityChanged",
|
||||||
user.user_id,
|
user.user_id,
|
||||||
serialize_to_list(beatmap_availability),
|
(beatmap_availability),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def AddPlaylistItem(self, client: Client, item: PlaylistItem):
|
async def AddPlaylistItem(self, client: Client, item: PlaylistItem):
|
||||||
@@ -198,14 +212,14 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(room.room.room_id),
|
self.group_id(room.room.room_id),
|
||||||
"SettingsChanged",
|
"SettingsChanged",
|
||||||
serialize_to_list(room.room.settings),
|
(room.room.settings),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem):
|
async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem):
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(room.room.room_id),
|
self.group_id(room.room.room_id),
|
||||||
"PlaylistItemAdded",
|
"PlaylistItemAdded",
|
||||||
serialize_to_list(item),
|
(item),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int):
|
async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int):
|
||||||
@@ -221,7 +235,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(room.room.room_id),
|
self.group_id(room.room.room_id),
|
||||||
"PlaylistItemChanged",
|
"PlaylistItemChanged",
|
||||||
serialize_to_list(item),
|
(item),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ChangeUserStyle(
|
async def ChangeUserStyle(
|
||||||
@@ -378,7 +392,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
)
|
)
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
incompatible_mods = [
|
incompatible_mods = [
|
||||||
mod.acronym for mod in new_mods if mod not in valid_mods
|
mod["acronym"] for mod in new_mods if mod not in valid_mods
|
||||||
]
|
]
|
||||||
raise InvokeException(
|
raise InvokeException(
|
||||||
f"Incompatible mods were selected: {','.join(incompatible_mods)}"
|
f"Incompatible mods were selected: {','.join(incompatible_mods)}"
|
||||||
@@ -395,3 +409,221 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
|||||||
user.user_id,
|
user.user_id,
|
||||||
valid_mods,
|
valid_mods,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def validate_user_stare(
|
||||||
|
self,
|
||||||
|
room: ServerMultiplayerRoom,
|
||||||
|
old: MultiplayerUserState,
|
||||||
|
new: MultiplayerUserState,
|
||||||
|
):
|
||||||
|
assert room.queue
|
||||||
|
match new:
|
||||||
|
case MultiplayerUserState.IDLE:
|
||||||
|
if old.is_playing:
|
||||||
|
raise InvokeException(
|
||||||
|
"Cannot return to idle without aborting gameplay."
|
||||||
|
)
|
||||||
|
case MultiplayerUserState.READY:
|
||||||
|
if old != MultiplayerUserState.IDLE:
|
||||||
|
raise InvokeException(f"Cannot change state from {old} to {new}")
|
||||||
|
if room.queue.current_item.expired:
|
||||||
|
raise InvokeException(
|
||||||
|
"Cannot ready up while all items have been played."
|
||||||
|
)
|
||||||
|
case MultiplayerUserState.WAITING_FOR_LOAD:
|
||||||
|
raise InvokeException("Cannot change state from {old} to {new}")
|
||||||
|
case MultiplayerUserState.LOADED:
|
||||||
|
if old != MultiplayerUserState.WAITING_FOR_LOAD:
|
||||||
|
raise InvokeException(f"Cannot change state from {old} to {new}")
|
||||||
|
case MultiplayerUserState.READY_FOR_GAMEPLAY:
|
||||||
|
if old != MultiplayerUserState.LOADED:
|
||||||
|
raise InvokeException(f"Cannot change state from {old} to {new}")
|
||||||
|
case MultiplayerUserState.PLAYING:
|
||||||
|
raise InvokeException("State is managed by the server.")
|
||||||
|
case MultiplayerUserState.FINISHED_PLAY:
|
||||||
|
if old != MultiplayerUserState.PLAYING:
|
||||||
|
raise InvokeException(f"Cannot change state from {old} to {new}")
|
||||||
|
case MultiplayerUserState.RESULTS:
|
||||||
|
raise InvokeException("Cannot change state from {old} to {new}")
|
||||||
|
case MultiplayerUserState.SPECTATING:
|
||||||
|
if old not in (MultiplayerUserState.IDLE, MultiplayerUserState.READY):
|
||||||
|
raise InvokeException(f"Cannot change state from {old} to {new}")
|
||||||
|
|
||||||
|
async def ChangeState(self, client: Client, state: MultiplayerUserState):
|
||||||
|
store = self.get_or_create_state(client)
|
||||||
|
if store.room_id == 0:
|
||||||
|
raise InvokeException("You are not in a room")
|
||||||
|
if store.room_id not in self.rooms:
|
||||||
|
raise InvokeException("Room does not exist")
|
||||||
|
server_room = self.rooms[store.room_id]
|
||||||
|
room = server_room.room
|
||||||
|
user = next((u for u in room.users if u.user_id == client.user_id), None)
|
||||||
|
if user is None:
|
||||||
|
raise InvokeException("You are not in this room")
|
||||||
|
|
||||||
|
if user.state == state:
|
||||||
|
return
|
||||||
|
match state:
|
||||||
|
case MultiplayerUserState.IDLE:
|
||||||
|
if user.state.is_playing:
|
||||||
|
return
|
||||||
|
case MultiplayerUserState.LOADED | MultiplayerUserState.READY_FOR_GAMEPLAY:
|
||||||
|
if not user.state.is_playing:
|
||||||
|
return
|
||||||
|
await self.validate_user_stare(
|
||||||
|
server_room,
|
||||||
|
user.state,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
await self.change_user_state(server_room, user, state)
|
||||||
|
if state == MultiplayerUserState.SPECTATING and (
|
||||||
|
room.state == MultiplayerRoomState.PLAYING
|
||||||
|
or room.state == MultiplayerRoomState.WAITING_FOR_LOAD
|
||||||
|
):
|
||||||
|
await self.call_noblock(client, "LoadRequested")
|
||||||
|
await self.update_room_state(server_room)
|
||||||
|
|
||||||
|
async def change_user_state(
|
||||||
|
self,
|
||||||
|
room: ServerMultiplayerRoom,
|
||||||
|
user: MultiplayerRoomUser,
|
||||||
|
state: MultiplayerUserState,
|
||||||
|
):
|
||||||
|
user.state = state
|
||||||
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(room.room.room_id),
|
||||||
|
"UserStateChanged",
|
||||||
|
user.user_id,
|
||||||
|
user.state,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_room_state(self, room: ServerMultiplayerRoom):
|
||||||
|
match room.room.state:
|
||||||
|
case MultiplayerRoomState.WAITING_FOR_LOAD:
|
||||||
|
played_count = len(
|
||||||
|
[True for user in room.room.users if user.state.is_playing]
|
||||||
|
)
|
||||||
|
ready_count = len(
|
||||||
|
[
|
||||||
|
True
|
||||||
|
for user in room.room.users
|
||||||
|
if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if played_count == ready_count:
|
||||||
|
await self.start_gameplay(room)
|
||||||
|
case MultiplayerRoomState.PLAYING:
|
||||||
|
assert room.queue
|
||||||
|
if all(
|
||||||
|
u.state != MultiplayerUserState.PLAYING for u in room.room.users
|
||||||
|
):
|
||||||
|
for u in filter(
|
||||||
|
lambda u: u.state == MultiplayerUserState.FINISHED_PLAY,
|
||||||
|
room.room.users,
|
||||||
|
):
|
||||||
|
await self.change_user_state(
|
||||||
|
room, u, MultiplayerUserState.RESULTS
|
||||||
|
)
|
||||||
|
await self.change_room_state(room, MultiplayerRoomState.OPEN)
|
||||||
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(room.room.room_id),
|
||||||
|
"ResultsReady",
|
||||||
|
)
|
||||||
|
await room.queue.finish_current_item()
|
||||||
|
|
||||||
|
async def change_room_state(
|
||||||
|
self, room: ServerMultiplayerRoom, state: MultiplayerRoomState
|
||||||
|
):
|
||||||
|
room.room.state = state
|
||||||
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(room.room.room_id),
|
||||||
|
"RoomStateChanged",
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def StartMatch(self, client: Client):
|
||||||
|
store = self.get_or_create_state(client)
|
||||||
|
if store.room_id == 0:
|
||||||
|
raise InvokeException("You are not in a room")
|
||||||
|
if store.room_id not in self.rooms:
|
||||||
|
raise InvokeException("Room does not exist")
|
||||||
|
server_room = self.rooms[store.room_id]
|
||||||
|
room = server_room.room
|
||||||
|
user = next((u for u in room.users if u.user_id == client.user_id), None)
|
||||||
|
if user is None:
|
||||||
|
raise InvokeException("You are not in this room")
|
||||||
|
if room.host is None or room.host.user_id != client.user_id:
|
||||||
|
raise InvokeException("You are not the host of this room")
|
||||||
|
if any(u.state != MultiplayerUserState.READY for u in room.users):
|
||||||
|
raise InvokeException("Not all users are ready")
|
||||||
|
|
||||||
|
await self.start_match(server_room)
|
||||||
|
|
||||||
|
async def start_match(self, room: ServerMultiplayerRoom):
|
||||||
|
assert room.queue
|
||||||
|
if room.room.state != MultiplayerRoomState.OPEN:
|
||||||
|
raise InvokeException("Can't start match when already in a running state.")
|
||||||
|
if room.queue.current_item.expired:
|
||||||
|
raise InvokeException("Current playlist item is expired")
|
||||||
|
ready_users = [
|
||||||
|
u
|
||||||
|
for u in room.room.users
|
||||||
|
if u.availability.state == DownloadState.LOCALLY_AVAILABLE
|
||||||
|
and (
|
||||||
|
u.state == MultiplayerUserState.READY
|
||||||
|
or u.state == MultiplayerUserState.IDLE
|
||||||
|
)
|
||||||
|
]
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD)
|
||||||
|
for u in ready_users
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await self.change_room_state(
|
||||||
|
room,
|
||||||
|
MultiplayerRoomState.WAITING_FOR_LOAD,
|
||||||
|
)
|
||||||
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(room.room.room_id),
|
||||||
|
"LoadRequested",
|
||||||
|
)
|
||||||
|
await room.start_countdown(
|
||||||
|
ForceGameplayStartCountdown(
|
||||||
|
remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)
|
||||||
|
),
|
||||||
|
self.start_gameplay,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def start_gameplay(self, room: ServerMultiplayerRoom):
|
||||||
|
assert room.queue
|
||||||
|
if room.room.state != MultiplayerRoomState.WAITING_FOR_LOAD:
|
||||||
|
raise InvokeException("Room is not ready for gameplay")
|
||||||
|
if room.queue.current_item.expired:
|
||||||
|
raise InvokeException("Current playlist item is expired")
|
||||||
|
playing = False
|
||||||
|
for user in room.room.users:
|
||||||
|
client = self.get_client_by_id(str(user.user_id))
|
||||||
|
if client is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user.state in (
|
||||||
|
MultiplayerUserState.READY_FOR_GAMEPLAY,
|
||||||
|
MultiplayerUserState.LOADED,
|
||||||
|
):
|
||||||
|
playing = True
|
||||||
|
await self.change_user_state(room, user, MultiplayerUserState.PLAYING)
|
||||||
|
await self.call_noblock(client, "GameplayStarted")
|
||||||
|
await self.change_room_state(
|
||||||
|
room,
|
||||||
|
(MultiplayerRoomState.PLAYING if playing else MultiplayerRoomState.OPEN),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_match_event(
|
||||||
|
self, room: ServerMultiplayerRoom, event: MatchServerEvent
|
||||||
|
):
|
||||||
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(room.room.room_id),
|
||||||
|
"MatchEvent",
|
||||||
|
event,
|
||||||
|
)
|
||||||
|
|||||||
@@ -97,6 +97,8 @@ class MsgpackProtocol:
|
|||||||
return [cls.serialize_msgpack(item) for item in v]
|
return [cls.serialize_msgpack(item) for item in v]
|
||||||
elif issubclass(typ, datetime.datetime):
|
elif issubclass(typ, datetime.datetime):
|
||||||
return [v, 0]
|
return [v, 0]
|
||||||
|
elif issubclass(typ, datetime.timedelta):
|
||||||
|
return int(v.total_seconds())
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
return {
|
return {
|
||||||
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
||||||
@@ -213,6 +215,8 @@ class MsgpackProtocol:
|
|||||||
return typ.model_validate(obj=cls.process_object(v, typ))
|
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||||
return v[0]
|
return v[0]
|
||||||
|
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||||
|
return datetime.timedelta(seconds=int(v))
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||||
@@ -296,21 +300,30 @@ class MsgpackProtocol:
|
|||||||
|
|
||||||
class JSONProtocol:
|
class JSONProtocol:
|
||||||
@classmethod
|
@classmethod
|
||||||
def serialize_to_json(cls, v: Any):
|
def serialize_to_json(cls, v: Any, dict_key: bool = False):
|
||||||
typ = v.__class__
|
typ = v.__class__
|
||||||
if issubclass(typ, BaseModel):
|
if issubclass(typ, BaseModel):
|
||||||
return cls.serialize_model(v)
|
return cls.serialize_model(v)
|
||||||
elif isinstance(v, dict):
|
elif isinstance(v, dict):
|
||||||
return {
|
return {
|
||||||
cls.serialize_to_json(k): cls.serialize_to_json(value)
|
cls.serialize_to_json(k, True): cls.serialize_to_json(value)
|
||||||
for k, value in v.items()
|
for k, value in v.items()
|
||||||
}
|
}
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
return [cls.serialize_to_json(item) for item in v]
|
return [cls.serialize_to_json(item) for item in v]
|
||||||
elif isinstance(v, datetime.datetime):
|
elif isinstance(v, datetime.datetime):
|
||||||
return v.isoformat()
|
return v.isoformat()
|
||||||
elif isinstance(v, Enum):
|
elif isinstance(v, datetime.timedelta):
|
||||||
|
# d.hh:mm:ss
|
||||||
|
total_seconds = int(v.total_seconds())
|
||||||
|
hours, remainder = divmod(total_seconds, 3600)
|
||||||
|
minutes, seconds = divmod(remainder, 60)
|
||||||
|
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
||||||
|
elif isinstance(v, Enum) and dict_key:
|
||||||
return v.value
|
return v.value
|
||||||
|
elif isinstance(v, Enum):
|
||||||
|
list_ = list(typ)
|
||||||
|
return list_.index(v)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -322,9 +335,13 @@ class JSONProtocol:
|
|||||||
)
|
)
|
||||||
if metadata and metadata.json_ignore:
|
if metadata and metadata.json_ignore:
|
||||||
continue
|
continue
|
||||||
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
|
d[
|
||||||
cls.serialize_to_json(getattr(v, field))
|
snake_to_camel(
|
||||||
)
|
field,
|
||||||
|
metadata.use_upper_case if metadata else False,
|
||||||
|
metadata.use_abbr if metadata else True,
|
||||||
|
)
|
||||||
|
] = cls.serialize_to_json(getattr(v, field))
|
||||||
if issubclass(v.__class__, SignalRUnionMessage):
|
if issubclass(v.__class__, SignalRUnionMessage):
|
||||||
return {
|
return {
|
||||||
"$dtype": v.__class__.__name__,
|
"$dtype": v.__class__.__name__,
|
||||||
@@ -343,7 +360,11 @@ class JSONProtocol:
|
|||||||
)
|
)
|
||||||
if metadata and metadata.json_ignore:
|
if metadata and metadata.json_ignore:
|
||||||
continue
|
continue
|
||||||
value = v.get(snake_to_camel(field, not from_union))
|
value = v.get(
|
||||||
|
snake_to_camel(
|
||||||
|
field, not from_union, metadata.use_abbr if metadata else True
|
||||||
|
)
|
||||||
|
)
|
||||||
anno = typ.model_fields[field].annotation
|
anno = typ.model_fields[field].annotation
|
||||||
if anno is None:
|
if anno is None:
|
||||||
d[field] = value
|
d[field] = value
|
||||||
@@ -401,6 +422,17 @@ class JSONProtocol:
|
|||||||
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||||
return datetime.datetime.fromisoformat(v)
|
return datetime.datetime.fromisoformat(v)
|
||||||
|
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||||
|
# d.hh:mm:ss
|
||||||
|
parts = v.split(":")
|
||||||
|
if len(parts) == 3:
|
||||||
|
return datetime.timedelta(
|
||||||
|
hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])
|
||||||
|
)
|
||||||
|
elif len(parts) == 2:
|
||||||
|
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
|
||||||
|
elif len(parts) == 1:
|
||||||
|
return datetime.timedelta(seconds=int(parts[0]))
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def camel_to_snake(name: str) -> str:
|
|||||||
return "".join(result)
|
return "".join(result)
|
||||||
|
|
||||||
|
|
||||||
def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
def snake_to_camel(name: str, lower_case: bool = True, use_abbr: bool = True) -> str:
|
||||||
"""Convert a snake_case string to camelCase."""
|
"""Convert a snake_case string to camelCase."""
|
||||||
if not name:
|
if not name:
|
||||||
return name
|
return name
|
||||||
@@ -47,7 +47,7 @@ def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
|||||||
|
|
||||||
result = []
|
result = []
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if part.lower() in abbreviations:
|
if part.lower() in abbreviations and use_abbr:
|
||||||
result.append(part.upper())
|
result.append(part.upper())
|
||||||
else:
|
else:
|
||||||
if result or not lower_case:
|
if result or not lower_case:
|
||||||
|
|||||||
Reference in New Issue
Block a user