From 722a6e57d809e3cbef389c13ad15d40d92c612ab Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 28 Jul 2025 05:52:48 +0000 Subject: [PATCH] feat(spectator): support spectate solo player --- app/models/metadata_hub.py | 28 +++++----- app/models/signalr.py | 39 +++++++++++++- app/models/spectator_hub.py | 38 ++++++------- app/signalr/hub/hub.py | 17 +++--- app/signalr/hub/metadata.py | 10 ++-- app/signalr/hub/spectator.py | 84 ++++++++++++++++++++++++++--- app/signalr/packet.py | 100 ++++++++++++++++++++--------------- 7 files changed, 213 insertions(+), 103 deletions(-) diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 24a29b5..615ea9b 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -28,8 +28,7 @@ class _UserActivity(BaseModel): class ChoosingBeatmap(_UserActivity): - type: Literal["ChoosingBeatmap"] = "ChoosingBeatmap" - value: Literal[None] = None + type: Literal["ChoosingBeatmap"] = Field(alias="$dtype") class InGameValue(BaseModel): @@ -44,19 +43,19 @@ class _InGame(_UserActivity): class InSoloGame(_InGame): - type: Literal["InSoloGame"] = "InSoloGame" + type: Literal["InSoloGame"] = Field(alias="$dtype") class InMultiplayerGame(_InGame): - type: Literal["InMultiplayerGame"] = "InMultiplayerGame" + type: Literal["InMultiplayerGame"] = Field(alias="$dtype") class SpectatingMultiplayerGame(_InGame): - type: Literal["SpectatingMultiplayerGame"] = "SpectatingMultiplayerGame" + type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype") class InPlaylistGame(_InGame): - type: Literal["InPlaylistGame"] = "InPlaylistGame" + type: Literal["InPlaylistGame"] = Field(alias="$dtype") class EditingBeatmapValue(BaseModel): @@ -65,16 +64,16 @@ class EditingBeatmapValue(BaseModel): class EditingBeatmap(_UserActivity): - type: Literal["EditingBeatmap"] = "EditingBeatmap" + type: Literal["EditingBeatmap"] = Field(alias="$dtype") value: EditingBeatmapValue = Field(alias="$value") class TestingBeatmap(_UserActivity): - type: Literal["TestingBeatmap"] = "TestingBeatmap" + type: Literal["TestingBeatmap"] = Field(alias="$dtype") class ModdingBeatmap(_UserActivity): - type: Literal["ModdingBeatmap"] = "ModdingBeatmap" + type: Literal["ModdingBeatmap"] = Field(alias="$dtype") class WatchingReplayValue(BaseModel): @@ -85,17 +84,16 @@ class WatchingReplayValue(BaseModel): class WatchingReplay(_UserActivity): - type: Literal["WatchingReplay"] = "WatchingReplay" + type: Literal["WatchingReplay"] = Field(alias="$dtype") value: int | None = Field(alias="$value") # Replay ID class SpectatingUser(WatchingReplay): - type: Literal["SpectatingUser"] = "SpectatingUser" + type: Literal["SpectatingUser"] = Field(alias="$dtype") class SearchingForLobby(_UserActivity): - type: Literal["SearchingForLobby"] = "SearchingForLobby" - value: None = Field(alias="$value") + type: Literal["SearchingForLobby"] = Field(alias="$dtype") class InLobbyValue(BaseModel): @@ -105,12 +103,10 @@ class InLobbyValue(BaseModel): class InLobby(_UserActivity): type: Literal["InLobby"] = "InLobby" - value: None = Field(alias="$value") class InDailyChallengeLobby(_UserActivity): - type: Literal["InDailyChallengeLobby"] = "InDailyChallengeLobby" - value: None = Field(alias="$value") + type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype") UserActivity = ( diff --git a/app/models/signalr.py b/app/models/signalr.py index 49db11f..ac8475f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,11 +1,42 @@ from __future__ import annotations -from typing import Any +import datetime +from typing import Any, get_origin -from pydantic import BaseModel, Field, model_validator +import msgpack +from pydantic import ( + BaseModel, + ConfigDict, + Field, + TypeAdapter, + model_serializer, + model_validator, +) + + +def serialize_to_list(value: BaseModel) -> list[Any]: + data = [] + for field, info in value.__class__.model_fields.items(): + v = getattr(value, field) + anno = get_origin(info.annotation) + if anno and issubclass(anno, BaseModel): + data.append(serialize_to_list(v)) + elif anno and issubclass(anno, list): + data.append( + TypeAdapter( + info.annotation, + ).dump_python(v) + ) + elif isinstance(v, datetime.datetime): + data.append([msgpack.ext.Timestamp.from_datetime(v), 0]) + else: + data.append(v) + return data class MessagePackArrayModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + @model_validator(mode="before") @classmethod def unpack(cls, v: Any) -> Any: @@ -16,6 +47,10 @@ class MessagePackArrayModel(BaseModel): return dict(zip(fields, v)) return v + @model_serializer + def serialize(self) -> list[Any]: + return serialize_to_list(self) + class Transport(BaseModel): transport: str diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index fe95930..820eb16 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -12,7 +12,7 @@ from .score import ( from .signalr import MessagePackArrayModel import msgpack -from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator +from pydantic import BaseModel, Field, field_validator class APIMod(MessagePackArrayModel): @@ -58,8 +58,6 @@ class ScoreProcessorStatistics(MessagePackArrayModel): class FrameHeader(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - total_score: int acc: float combo: int @@ -84,25 +82,21 @@ class FrameHeader(MessagePackArrayModel): return datetime.datetime.fromisoformat(v) raise ValueError(f"Cannot convert {type(v)} to datetime") - @field_serializer("received_time") - def serialize_received_time(self, v: datetime.datetime) -> msgpack.ext.Timestamp: - return msgpack.ext.Timestamp.from_datetime(v) - -class ReplayButtonState(IntEnum): - NONE = 0 - LEFT1 = 1 - RIGHT1 = 2 - LEFT2 = 4 - RIGHT2 = 8 - SMOKE = 16 +# class ReplayButtonState(IntEnum): +# NONE = 0 +# LEFT1 = 1 +# RIGHT1 = 2 +# LEFT2 = 4 +# RIGHT2 = 8 +# SMOKE = 16 class LegacyReplayFrame(MessagePackArrayModel): time: float # from ReplayFrame,the parent of LegacyReplayFrame x: float | None = None y: float | None = None - button_state: ReplayButtonState + button_state: int class FrameDataBundle(MessagePackArrayModel): @@ -135,10 +129,10 @@ class StoreScore(BaseModel): class StoreClientState(BaseModel): - state: SpectatorState | None - beatmap_status: BeatmapRankStatus - checksum: str - ruleset_id: int - score_token: int - watched_user: set[int] - score: StoreScore + state: SpectatorState | None = None + beatmap_status: BeatmapRankStatus | None = None + checksum: str | None = None + ruleset_id: int | None = None + score_token: int | None = None + watched_user: set[int] = Field(default_factory=set) + score: StoreScore | None = None diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 40e75da..e5c807c 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -44,11 +44,11 @@ class Client: async def send_packet(self, packet: Packet): await self.connection.send_bytes(self.procotol.encode(packet)) - async def receive_packet(self) -> Packet: + async def receive_packets(self) -> list[Packet]: message = await self.connection.receive() d = message.get("bytes") or message.get("text", "").encode() if not d: - return PingPacket() # FIXME: Graceful empty message handling + return [PingPacket()] # FIXME: Graceful empty message handling return self.procotol.decode(d) async def _ping(self): @@ -138,12 +138,13 @@ class Hub: jump = False while not jump: try: - packet = await client.receive_packet() - task = asyncio.create_task(self._handle_packet(client, packet)) - self.tasks.add(task) - task.add_done_callback(self.tasks.discard) - except StopIteration: - pass + packets = await client.receive_packets() + for packet in packets: + if isinstance(packet, PingPacket): + continue + task = asyncio.create_task(self._handle_packet(client, packet)) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) except WebSocketDisconnect as e: print( f"Client {client.connection_id} disconnected: {e.code}, {e.reason}" diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index cfed09c..4229723 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -106,11 +106,12 @@ class MetadataHub(Hub): await asyncio.gather(*tasks) async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None: - if activity_dict is None: - # idle - return user_id = int(client.connection_id) - activity = TypeAdapter(UserActivity).validate_python(activity_dict) + activity = ( + TypeAdapter(UserActivity).validate_python(activity_dict) + if activity_dict + else None + ) store = self.state.get(user_id) if store: store.user_activity = activity @@ -119,7 +120,6 @@ class MetadataHub(Hub): user_activity=activity, ) self.state[user_id] = store - tasks = self.broadcast_tasks(user_id, store) tasks.add( self.call_noblock( diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index b3e6f10..c0be2dd 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import lzma import struct @@ -13,6 +14,7 @@ from app.dependencies.database import engine from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt +from app.models.signalr import serialize_to_list from app.models.spectator_hub import ( APIUser, FrameDataBundle, @@ -106,7 +108,7 @@ def save_replay( for frame in frames: frame_strs.append( f"{frame.time - last_time}|{frame.x or 0.0}" - f"|{frame.y or 0.0}|{frame.button_state.value}" + f"|{frame.y or 0.0}|{frame.button_state}" ) last_time = frame.time frame_strs.append("-12345|0|0|0") @@ -143,6 +145,20 @@ class SpectatorHub(Hub): super().__init__() self.state: dict[int, StoreClientState] = {} + @staticmethod + def group_id(user_id: int) -> str: + return f"watch:{user_id}" + + async def on_client_connect(self, client: Client) -> None: + tasks = [ + self.call_noblock( + client, "UserBeganPlaying", user_id, serialize_to_list(store.state) + ) + for user_id, store in self.state.items() + if store.state is not None + ] + await asyncio.gather(*tasks) + async def BeginPlaySession( self, client: Client, score_token: int, state: SpectatorState ) -> None: @@ -184,7 +200,12 @@ class SpectatorHub(Hub): ), ) self.state[user_id] = store - await self.broadcast_call("UserBeganPlaying", user_id, state.model_dump()) + await self.broadcast_group_call( + self.group_id(user_id), + "UserBeganPlaying", + user_id, + serialize_to_list(state), + ) async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: user_id = int(client.connection_id) @@ -201,14 +222,14 @@ class SpectatorHub(Hub): score.score_info.total_score = frame_data.header.total_score score.score_info.mods = frame_data.header.mods score.replay_frames.extend(frame_data.frames) - await self.broadcast_call( + await self.broadcast_group_call( + self.group_id(user_id), "UserSentFrames", user_id, frame_data.model_dump(), ) async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: - print(f"EndPlaySession -> {client.connection_id} {state.model_dump()!r}") user_id = int(client.connection_id) store = self.state.get(user_id) if not store: @@ -217,7 +238,13 @@ class SpectatorHub(Hub): if not score or not store.score_token: return + assert store.beatmap_status is not None + async def _save_replay(): + assert store.checksum is not None + assert store.ruleset_id is not None + assert store.state is not None + assert store.score is not None async with AsyncSession(engine) as session: async with session: start_time = time.time() @@ -271,8 +298,51 @@ class SpectatorHub(Hub): del self.state[user_id] if state.state == SpectatedUserState.Playing: state.state = SpectatedUserState.Quit - await self.broadcast_call( - "UserEndedPlaying", + await self.broadcast_group_call( + self.group_id(user_id), + "UserFinishedPlaying", user_id, - state.model_dump(), + serialize_to_list(state) if state else None, ) + + async def StartWatchingUser(self, client: Client, target_id: int) -> None: + print(f"StartWatchingUser -> {client.connection_id} {target_id}") + user_id = int(client.connection_id) + target_store = self.state.get(target_id) + if target_store and target_store.state: + await self.call_noblock( + client, + "UserBeganPlaying", + target_id, + serialize_to_list(target_store.state), + ) + store = self.state.get(user_id) + if store is None: + store = StoreClientState( + watched_user=set(), + ) + store.watched_user.add(target_id) + self.state[user_id] = store + self.groups.setdefault(self.group_id(target_id), set()).add(client) + + async with AsyncSession(engine) as session: + async with session.begin(): + username = ( + await session.exec(select(User.name).where(User.id == user_id)) + ).first() + if not username: + return + if (target_client := self.get_client_by_id(str(target_id))) is not None: + await self.call_noblock( + target_client, "UserStartedWatching", [[user_id, username]] + ) + + async def EndWatchingUser(self, client: Client, target_id: int) -> None: + print(f"EndWatchingUser -> {client.connection_id} {target_id}") + user_id = int(client.connection_id) + self.groups[self.group_id(target_id)].discard(client) + store = self.state.get(user_id) + if store: + store.watched_user.discard(target_id) + if (target_client := self.get_client_by_id(str(target_id))) is not None: + await self.call_noblock(target_client, "UserEndedWatching", user_id) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 5fa0908..1ff9b83 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -60,7 +60,7 @@ PACKETS = { class Protocol(TypingProtocol): @staticmethod - def decode(input: bytes) -> Packet: ... + def decode(input: bytes) -> list[Packet]: ... @staticmethod def encode(packet: Packet) -> bytes: ... @@ -93,7 +93,7 @@ class MsgpackProtocol: return result, pos @staticmethod - def decode(input: bytes) -> Packet: + def decode(input: bytes) -> list[Packet]: length, offset = MsgpackProtocol._decode_varint(input) message_data = input[offset : offset + length] # FIXME: custom deserializer for APIMod @@ -106,23 +106,27 @@ class MsgpackProtocol: raise ValueError(f"Unknown packet type: {packet_type}") match packet_type: case PacketType.INVOCATION: - return InvocationPacket( - header=unpacked[1], - invocation_id=unpacked[2], - target=unpacked[3], - arguments=unpacked[4] if len(unpacked) > 4 else None, - stream_ids=unpacked[5] if len(unpacked) > 5 else None, - ) + return [ + InvocationPacket( + header=unpacked[1], + invocation_id=unpacked[2], + target=unpacked[3], + arguments=unpacked[4] if len(unpacked) > 4 else None, + stream_ids=unpacked[5] if len(unpacked) > 5 else None, + ) + ] case PacketType.COMPLETION: result_kind = unpacked[3] - return CompletionPacket( - header=unpacked[1], - invocation_id=unpacked[2], - error=unpacked[4] if result_kind == 1 else None, - result=unpacked[5] if result_kind == 3 else None, - ) + return [ + CompletionPacket( + header=unpacked[1], + invocation_id=unpacked[2], + error=unpacked[4] if result_kind == 1 else None, + result=unpacked[5] if result_kind == 3 else None, + ) + ] case PacketType.PING: - return PingPacket() + return [PingPacket()] raise ValueError(f"Unsupported packet type: {packet_type}") @staticmethod @@ -153,38 +157,48 @@ class MsgpackProtocol: ] ) elif isinstance(packet, PingPacket): - pass - - data = msgpack.packb(payload, use_bin_type=True) + payload.pop(-1) + data = msgpack.packb(payload, use_bin_type=True, datetime=True) return MsgpackProtocol._encode_varint(len(data)) + data class JSONProtocol: @staticmethod - def decode(input: bytes) -> Packet: - data = json.loads(input[:-1].decode("utf-8")) - packet_type = PacketType(data["type"]) - if packet_type not in PACKETS: - raise ValueError(f"Unknown packet type: {packet_type}") - match packet_type: - case PacketType.INVOCATION: - return InvocationPacket( - header=data.get("header"), - invocation_id=data.get("invocationId"), - target=data["target"], - arguments=data.get("arguments"), - stream_ids=data.get("streamIds"), - ) - case PacketType.COMPLETION: - return CompletionPacket( - header=data.get("header"), - invocation_id=data["invocationId"], - error=data.get("error"), - result=data.get("result"), - ) - case PacketType.PING: - return PingPacket() - raise ValueError(f"Unsupported packet type: {packet_type}") + def decode(input: bytes) -> list[Packet]: + packets_raw = input.removesuffix(SEP).split(SEP) + packets = [] + if len(packets_raw) > 1: + for packet_raw in packets_raw: + packets.extend(JSONProtocol.decode(packet_raw)) + return packets + else: + data = json.loads(packets_raw[0]) + packet_type = PacketType(data["type"]) + if packet_type not in PACKETS: + raise ValueError(f"Unknown packet type: {packet_type}") + match packet_type: + case PacketType.INVOCATION: + return [ + InvocationPacket( + header=data.get("header"), + invocation_id=data.get("invocationId"), + target=data["target"], + arguments=data.get("arguments"), + stream_ids=data.get("streamIds"), + ) + ] + case PacketType.COMPLETION: + return [ + CompletionPacket( + header=data.get("header"), + invocation_id=data["invocationId"], + error=data.get("error"), + result=data.get("result"), + ) + ] + case PacketType.PING: + return [PingPacket()] + raise ValueError(f"Unsupported packet type: {packet_type}") @staticmethod def encode(packet: Packet) -> bytes: