feat(spectator): support spectate solo player

This commit is contained in:
MingxuanGame
2025-07-28 05:52:48 +00:00
parent 20d528d203
commit 722a6e57d8
7 changed files with 213 additions and 103 deletions

View File

@@ -28,8 +28,7 @@ class _UserActivity(BaseModel):
class ChoosingBeatmap(_UserActivity): class ChoosingBeatmap(_UserActivity):
type: Literal["ChoosingBeatmap"] = "ChoosingBeatmap" type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
value: Literal[None] = None
class InGameValue(BaseModel): class InGameValue(BaseModel):
@@ -44,19 +43,19 @@ class _InGame(_UserActivity):
class InSoloGame(_InGame): class InSoloGame(_InGame):
type: Literal["InSoloGame"] = "InSoloGame" type: Literal["InSoloGame"] = Field(alias="$dtype")
class InMultiplayerGame(_InGame): class InMultiplayerGame(_InGame):
type: Literal["InMultiplayerGame"] = "InMultiplayerGame" type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
class SpectatingMultiplayerGame(_InGame): class SpectatingMultiplayerGame(_InGame):
type: Literal["SpectatingMultiplayerGame"] = "SpectatingMultiplayerGame" type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
class InPlaylistGame(_InGame): class InPlaylistGame(_InGame):
type: Literal["InPlaylistGame"] = "InPlaylistGame" type: Literal["InPlaylistGame"] = Field(alias="$dtype")
class EditingBeatmapValue(BaseModel): class EditingBeatmapValue(BaseModel):
@@ -65,16 +64,16 @@ class EditingBeatmapValue(BaseModel):
class EditingBeatmap(_UserActivity): class EditingBeatmap(_UserActivity):
type: Literal["EditingBeatmap"] = "EditingBeatmap" type: Literal["EditingBeatmap"] = Field(alias="$dtype")
value: EditingBeatmapValue = Field(alias="$value") value: EditingBeatmapValue = Field(alias="$value")
class TestingBeatmap(_UserActivity): class TestingBeatmap(_UserActivity):
type: Literal["TestingBeatmap"] = "TestingBeatmap" type: Literal["TestingBeatmap"] = Field(alias="$dtype")
class ModdingBeatmap(_UserActivity): class ModdingBeatmap(_UserActivity):
type: Literal["ModdingBeatmap"] = "ModdingBeatmap" type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
class WatchingReplayValue(BaseModel): class WatchingReplayValue(BaseModel):
@@ -85,17 +84,16 @@ class WatchingReplayValue(BaseModel):
class WatchingReplay(_UserActivity): class WatchingReplay(_UserActivity):
type: Literal["WatchingReplay"] = "WatchingReplay" type: Literal["WatchingReplay"] = Field(alias="$dtype")
value: int | None = Field(alias="$value") # Replay ID value: int | None = Field(alias="$value") # Replay ID
class SpectatingUser(WatchingReplay): class SpectatingUser(WatchingReplay):
type: Literal["SpectatingUser"] = "SpectatingUser" type: Literal["SpectatingUser"] = Field(alias="$dtype")
class SearchingForLobby(_UserActivity): class SearchingForLobby(_UserActivity):
type: Literal["SearchingForLobby"] = "SearchingForLobby" type: Literal["SearchingForLobby"] = Field(alias="$dtype")
value: None = Field(alias="$value")
class InLobbyValue(BaseModel): class InLobbyValue(BaseModel):
@@ -105,12 +103,10 @@ class InLobbyValue(BaseModel):
class InLobby(_UserActivity): class InLobby(_UserActivity):
type: Literal["InLobby"] = "InLobby" type: Literal["InLobby"] = "InLobby"
value: None = Field(alias="$value")
class InDailyChallengeLobby(_UserActivity): class InDailyChallengeLobby(_UserActivity):
type: Literal["InDailyChallengeLobby"] = "InDailyChallengeLobby" type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
value: None = Field(alias="$value")
UserActivity = ( UserActivity = (

View File

@@ -1,11 +1,42 @@
from __future__ import annotations from __future__ import annotations
from typing import Any import datetime
from typing import Any, get_origin
from pydantic import BaseModel, Field, model_validator import msgpack
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
model_serializer,
model_validator,
)
def serialize_to_list(value: BaseModel) -> list[Any]:
data = []
for field, info in value.__class__.model_fields.items():
v = getattr(value, field)
anno = get_origin(info.annotation)
if anno and issubclass(anno, BaseModel):
data.append(serialize_to_list(v))
elif anno and issubclass(anno, list):
data.append(
TypeAdapter(
info.annotation,
).dump_python(v)
)
elif isinstance(v, datetime.datetime):
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
else:
data.append(v)
return data
class MessagePackArrayModel(BaseModel): class MessagePackArrayModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def unpack(cls, v: Any) -> Any: def unpack(cls, v: Any) -> Any:
@@ -16,6 +47,10 @@ class MessagePackArrayModel(BaseModel):
return dict(zip(fields, v)) return dict(zip(fields, v))
return v return v
@model_serializer
def serialize(self) -> list[Any]:
return serialize_to_list(self)
class Transport(BaseModel): class Transport(BaseModel):
transport: str transport: str

View File

@@ -12,7 +12,7 @@ from .score import (
from .signalr import MessagePackArrayModel from .signalr import MessagePackArrayModel
import msgpack import msgpack
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic import BaseModel, Field, field_validator
class APIMod(MessagePackArrayModel): class APIMod(MessagePackArrayModel):
@@ -58,8 +58,6 @@ class ScoreProcessorStatistics(MessagePackArrayModel):
class FrameHeader(MessagePackArrayModel): class FrameHeader(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
total_score: int total_score: int
acc: float acc: float
combo: int combo: int
@@ -84,25 +82,21 @@ class FrameHeader(MessagePackArrayModel):
return datetime.datetime.fromisoformat(v) return datetime.datetime.fromisoformat(v)
raise ValueError(f"Cannot convert {type(v)} to datetime") raise ValueError(f"Cannot convert {type(v)} to datetime")
@field_serializer("received_time")
def serialize_received_time(self, v: datetime.datetime) -> msgpack.ext.Timestamp:
return msgpack.ext.Timestamp.from_datetime(v)
# class ReplayButtonState(IntEnum):
class ReplayButtonState(IntEnum): # NONE = 0
NONE = 0 # LEFT1 = 1
LEFT1 = 1 # RIGHT1 = 2
RIGHT1 = 2 # LEFT2 = 4
LEFT2 = 4 # RIGHT2 = 8
RIGHT2 = 8 # SMOKE = 16
SMOKE = 16
class LegacyReplayFrame(MessagePackArrayModel): class LegacyReplayFrame(MessagePackArrayModel):
time: float # from ReplayFrame,the parent of LegacyReplayFrame time: float # from ReplayFrame,the parent of LegacyReplayFrame
x: float | None = None x: float | None = None
y: float | None = None y: float | None = None
button_state: ReplayButtonState button_state: int
class FrameDataBundle(MessagePackArrayModel): class FrameDataBundle(MessagePackArrayModel):
@@ -135,10 +129,10 @@ class StoreScore(BaseModel):
class StoreClientState(BaseModel): class StoreClientState(BaseModel):
state: SpectatorState | None state: SpectatorState | None = None
beatmap_status: BeatmapRankStatus beatmap_status: BeatmapRankStatus | None = None
checksum: str checksum: str | None = None
ruleset_id: int ruleset_id: int | None = None
score_token: int score_token: int | None = None
watched_user: set[int] watched_user: set[int] = Field(default_factory=set)
score: StoreScore score: StoreScore | None = None

View File

@@ -44,11 +44,11 @@ class Client:
async def send_packet(self, packet: Packet): async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet)) await self.connection.send_bytes(self.procotol.encode(packet))
async def receive_packet(self) -> Packet: async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive() message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode() d = message.get("bytes") or message.get("text", "").encode()
if not d: if not d:
return PingPacket() # FIXME: Graceful empty message handling return [PingPacket()] # FIXME: Graceful empty message handling
return self.procotol.decode(d) return self.procotol.decode(d)
async def _ping(self): async def _ping(self):
@@ -138,12 +138,13 @@ class Hub:
jump = False jump = False
while not jump: while not jump:
try: try:
packet = await client.receive_packet() packets = await client.receive_packets()
task = asyncio.create_task(self._handle_packet(client, packet)) for packet in packets:
self.tasks.add(task) if isinstance(packet, PingPacket):
task.add_done_callback(self.tasks.discard) continue
except StopIteration: task = asyncio.create_task(self._handle_packet(client, packet))
pass self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print( print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}" f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"

View File

@@ -106,11 +106,12 @@ class MetadataHub(Hub):
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None: async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
if activity_dict is None:
# idle
return
user_id = int(client.connection_id) user_id = int(client.connection_id)
activity = TypeAdapter(UserActivity).validate_python(activity_dict) activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.state.get(user_id) store = self.state.get(user_id)
if store: if store:
store.user_activity = activity store.user_activity = activity
@@ -119,7 +120,6 @@ class MetadataHub(Hub):
user_activity=activity, user_activity=activity,
) )
self.state[user_id] = store self.state[user_id] = store
tasks = self.broadcast_tasks(user_id, store) tasks = self.broadcast_tasks(user_id, store)
tasks.add( tasks.add(
self.call_noblock( self.call_noblock(

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import lzma import lzma
import struct import struct
@@ -13,6 +14,7 @@ from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
from app.models.signalr import serialize_to_list
from app.models.spectator_hub import ( from app.models.spectator_hub import (
APIUser, APIUser,
FrameDataBundle, FrameDataBundle,
@@ -106,7 +108,7 @@ def save_replay(
for frame in frames: for frame in frames:
frame_strs.append( frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}" f"{frame.time - last_time}|{frame.x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state.value}" f"|{frame.y or 0.0}|{frame.button_state}"
) )
last_time = frame.time last_time = frame.time
frame_strs.append("-12345|0|0|0") frame_strs.append("-12345|0|0|0")
@@ -143,6 +145,20 @@ class SpectatorHub(Hub):
super().__init__() super().__init__()
self.state: dict[int, StoreClientState] = {} self.state: dict[int, StoreClientState] = {}
@staticmethod
def group_id(user_id: int) -> str:
return f"watch:{user_id}"
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
for user_id, store in self.state.items()
if store.state is not None
]
await asyncio.gather(*tasks)
async def BeginPlaySession( async def BeginPlaySession(
self, client: Client, score_token: int, state: SpectatorState self, client: Client, score_token: int, state: SpectatorState
) -> None: ) -> None:
@@ -184,7 +200,12 @@ class SpectatorHub(Hub):
), ),
) )
self.state[user_id] = store self.state[user_id] = store
await self.broadcast_call("UserBeganPlaying", user_id, state.model_dump()) await self.broadcast_group_call(
self.group_id(user_id),
"UserBeganPlaying",
user_id,
serialize_to_list(state),
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
user_id = int(client.connection_id) user_id = int(client.connection_id)
@@ -201,14 +222,14 @@ class SpectatorHub(Hub):
score.score_info.total_score = frame_data.header.total_score score.score_info.total_score = frame_data.header.total_score
score.score_info.mods = frame_data.header.mods score.score_info.mods = frame_data.header.mods
score.replay_frames.extend(frame_data.frames) score.replay_frames.extend(frame_data.frames)
await self.broadcast_call( await self.broadcast_group_call(
self.group_id(user_id),
"UserSentFrames", "UserSentFrames",
user_id, user_id,
frame_data.model_dump(), frame_data.model_dump(),
) )
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
print(f"EndPlaySession -> {client.connection_id} {state.model_dump()!r}")
user_id = int(client.connection_id) user_id = int(client.connection_id)
store = self.state.get(user_id) store = self.state.get(user_id)
if not store: if not store:
@@ -217,7 +238,13 @@ class SpectatorHub(Hub):
if not score or not store.score_token: if not score or not store.score_token:
return return
assert store.beatmap_status is not None
async def _save_replay(): async def _save_replay():
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.state is not None
assert store.score is not None
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
async with session: async with session:
start_time = time.time() start_time = time.time()
@@ -271,8 +298,51 @@ class SpectatorHub(Hub):
del self.state[user_id] del self.state[user_id]
if state.state == SpectatedUserState.Playing: if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit state.state = SpectatedUserState.Quit
await self.broadcast_call( await self.broadcast_group_call(
"UserEndedPlaying", self.group_id(user_id),
"UserFinishedPlaying",
user_id, user_id,
state.model_dump(), serialize_to_list(state) if state else None,
) )
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
target_store = self.state.get(target_id)
if target_store and target_store.state:
await self.call_noblock(
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
)
store = self.state.get(user_id)
if store is None:
store = StoreClientState(
watched_user=set(),
)
store.watched_user.add(target_id)
self.state[user_id] = store
self.groups.setdefault(self.group_id(target_id), set()).add(client)
async with AsyncSession(engine) as session:
async with session.begin():
username = (
await session.exec(select(User.name).where(User.id == user_id))
).first()
if not username:
return
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(
target_client, "UserStartedWatching", [[user_id, username]]
)
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
self.groups[self.group_id(target_id)].discard(client)
store = self.state.get(user_id)
if store:
store.watched_user.discard(target_id)
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)

View File

@@ -60,7 +60,7 @@ PACKETS = {
class Protocol(TypingProtocol): class Protocol(TypingProtocol):
@staticmethod @staticmethod
def decode(input: bytes) -> Packet: ... def decode(input: bytes) -> list[Packet]: ...
@staticmethod @staticmethod
def encode(packet: Packet) -> bytes: ... def encode(packet: Packet) -> bytes: ...
@@ -93,7 +93,7 @@ class MsgpackProtocol:
return result, pos return result, pos
@staticmethod @staticmethod
def decode(input: bytes) -> Packet: def decode(input: bytes) -> list[Packet]:
length, offset = MsgpackProtocol._decode_varint(input) length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length] message_data = input[offset : offset + length]
# FIXME: custom deserializer for APIMod # FIXME: custom deserializer for APIMod
@@ -106,23 +106,27 @@ class MsgpackProtocol:
raise ValueError(f"Unknown packet type: {packet_type}") raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type: match packet_type:
case PacketType.INVOCATION: case PacketType.INVOCATION:
return InvocationPacket( return [
header=unpacked[1], InvocationPacket(
invocation_id=unpacked[2], header=unpacked[1],
target=unpacked[3], invocation_id=unpacked[2],
arguments=unpacked[4] if len(unpacked) > 4 else None, target=unpacked[3],
stream_ids=unpacked[5] if len(unpacked) > 5 else None, arguments=unpacked[4] if len(unpacked) > 4 else None,
) stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
]
case PacketType.COMPLETION: case PacketType.COMPLETION:
result_kind = unpacked[3] result_kind = unpacked[3]
return CompletionPacket( return [
header=unpacked[1], CompletionPacket(
invocation_id=unpacked[2], header=unpacked[1],
error=unpacked[4] if result_kind == 1 else None, invocation_id=unpacked[2],
result=unpacked[5] if result_kind == 3 else None, error=unpacked[4] if result_kind == 1 else None,
) result=unpacked[5] if result_kind == 3 else None,
)
]
case PacketType.PING: case PacketType.PING:
return PingPacket() return [PingPacket()]
raise ValueError(f"Unsupported packet type: {packet_type}") raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod @staticmethod
@@ -153,38 +157,48 @@ class MsgpackProtocol:
] ]
) )
elif isinstance(packet, PingPacket): elif isinstance(packet, PingPacket):
pass payload.pop(-1)
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
data = msgpack.packb(payload, use_bin_type=True)
return MsgpackProtocol._encode_varint(len(data)) + data return MsgpackProtocol._encode_varint(len(data)) + data
class JSONProtocol: class JSONProtocol:
@staticmethod @staticmethod
def decode(input: bytes) -> Packet: def decode(input: bytes) -> list[Packet]:
data = json.loads(input[:-1].decode("utf-8")) packets_raw = input.removesuffix(SEP).split(SEP)
packet_type = PacketType(data["type"]) packets = []
if packet_type not in PACKETS: if len(packets_raw) > 1:
raise ValueError(f"Unknown packet type: {packet_type}") for packet_raw in packets_raw:
match packet_type: packets.extend(JSONProtocol.decode(packet_raw))
case PacketType.INVOCATION: return packets
return InvocationPacket( else:
header=data.get("header"), data = json.loads(packets_raw[0])
invocation_id=data.get("invocationId"), packet_type = PacketType(data["type"])
target=data["target"], if packet_type not in PACKETS:
arguments=data.get("arguments"), raise ValueError(f"Unknown packet type: {packet_type}")
stream_ids=data.get("streamIds"), match packet_type:
) case PacketType.INVOCATION:
case PacketType.COMPLETION: return [
return CompletionPacket( InvocationPacket(
header=data.get("header"), header=data.get("header"),
invocation_id=data["invocationId"], invocation_id=data.get("invocationId"),
error=data.get("error"), target=data["target"],
result=data.get("result"), arguments=data.get("arguments"),
) stream_ids=data.get("streamIds"),
case PacketType.PING: )
return PingPacket() ]
raise ValueError(f"Unsupported packet type: {packet_type}") case PacketType.COMPLETION:
return [
CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
]
case PacketType.PING:
return [PingPacket()]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod @staticmethod
def encode(packet: Packet) -> bytes: def encode(packet: Packet) -> bytes: