feat(spectator): support spectate solo player

This commit is contained in:
MingxuanGame
2025-07-28 05:52:48 +00:00
parent 20d528d203
commit 722a6e57d8
7 changed files with 213 additions and 103 deletions

View File

@@ -44,11 +44,11 @@ class Client:
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
async def receive_packet(self) -> Packet:
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return PingPacket() # FIXME: Graceful empty message handling
return [PingPacket()] # FIXME: Graceful empty message handling
return self.procotol.decode(d)
async def _ping(self):
@@ -138,12 +138,13 @@ class Hub:
jump = False
while not jump:
try:
packet = await client.receive_packet()
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except StopIteration:
pass
packets = await client.receive_packets()
for packet in packets:
if isinstance(packet, PingPacket):
continue
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"

View File

@@ -106,11 +106,12 @@ class MetadataHub(Hub):
await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
if activity_dict is None:
# idle
return
user_id = int(client.connection_id)
activity = TypeAdapter(UserActivity).validate_python(activity_dict)
activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.state.get(user_id)
if store:
store.user_activity = activity
@@ -119,7 +120,6 @@ class MetadataHub(Hub):
user_activity=activity,
)
self.state[user_id] = store
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import json
import lzma
import struct
@@ -13,6 +14,7 @@ from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
from app.models.signalr import serialize_to_list
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
@@ -106,7 +108,7 @@ def save_replay(
for frame in frames:
frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state.value}"
f"|{frame.y or 0.0}|{frame.button_state}"
)
last_time = frame.time
frame_strs.append("-12345|0|0|0")
@@ -143,6 +145,20 @@ class SpectatorHub(Hub):
super().__init__()
self.state: dict[int, StoreClientState] = {}
@staticmethod
def group_id(user_id: int) -> str:
return f"watch:{user_id}"
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
for user_id, store in self.state.items()
if store.state is not None
]
await asyncio.gather(*tasks)
async def BeginPlaySession(
self, client: Client, score_token: int, state: SpectatorState
) -> None:
@@ -184,7 +200,12 @@ class SpectatorHub(Hub):
),
)
self.state[user_id] = store
await self.broadcast_call("UserBeganPlaying", user_id, state.model_dump())
await self.broadcast_group_call(
self.group_id(user_id),
"UserBeganPlaying",
user_id,
serialize_to_list(state),
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
user_id = int(client.connection_id)
@@ -201,14 +222,14 @@ class SpectatorHub(Hub):
score.score_info.total_score = frame_data.header.total_score
score.score_info.mods = frame_data.header.mods
score.replay_frames.extend(frame_data.frames)
await self.broadcast_call(
await self.broadcast_group_call(
self.group_id(user_id),
"UserSentFrames",
user_id,
frame_data.model_dump(),
)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
print(f"EndPlaySession -> {client.connection_id} {state.model_dump()!r}")
user_id = int(client.connection_id)
store = self.state.get(user_id)
if not store:
@@ -217,7 +238,13 @@ class SpectatorHub(Hub):
if not score or not store.score_token:
return
assert store.beatmap_status is not None
async def _save_replay():
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.state is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with session:
start_time = time.time()
@@ -271,8 +298,51 @@ class SpectatorHub(Hub):
del self.state[user_id]
if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit
await self.broadcast_call(
"UserEndedPlaying",
await self.broadcast_group_call(
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
state.model_dump(),
serialize_to_list(state) if state else None,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
target_store = self.state.get(target_id)
if target_store and target_store.state:
await self.call_noblock(
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
)
store = self.state.get(user_id)
if store is None:
store = StoreClientState(
watched_user=set(),
)
store.watched_user.add(target_id)
self.state[user_id] = store
self.groups.setdefault(self.group_id(target_id), set()).add(client)
async with AsyncSession(engine) as session:
async with session.begin():
username = (
await session.exec(select(User.name).where(User.id == user_id))
).first()
if not username:
return
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(
target_client, "UserStartedWatching", [[user_id, username]]
)
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
self.groups[self.group_id(target_id)].discard(client)
store = self.state.get(user_id)
if store:
store.watched_user.discard(target_id)
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)

View File

@@ -60,7 +60,7 @@ PACKETS = {
class Protocol(TypingProtocol):
@staticmethod
def decode(input: bytes) -> Packet: ...
def decode(input: bytes) -> list[Packet]: ...
@staticmethod
def encode(packet: Packet) -> bytes: ...
@@ -93,7 +93,7 @@ class MsgpackProtocol:
return result, pos
@staticmethod
def decode(input: bytes) -> Packet:
def decode(input: bytes) -> list[Packet]:
length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length]
# FIXME: custom deserializer for APIMod
@@ -106,23 +106,27 @@ class MsgpackProtocol:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return InvocationPacket(
header=unpacked[1],
invocation_id=unpacked[2],
target=unpacked[3],
arguments=unpacked[4] if len(unpacked) > 4 else None,
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
return [
InvocationPacket(
header=unpacked[1],
invocation_id=unpacked[2],
target=unpacked[3],
arguments=unpacked[4] if len(unpacked) > 4 else None,
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
]
case PacketType.COMPLETION:
result_kind = unpacked[3]
return CompletionPacket(
header=unpacked[1],
invocation_id=unpacked[2],
error=unpacked[4] if result_kind == 1 else None,
result=unpacked[5] if result_kind == 3 else None,
)
return [
CompletionPacket(
header=unpacked[1],
invocation_id=unpacked[2],
error=unpacked[4] if result_kind == 1 else None,
result=unpacked[5] if result_kind == 3 else None,
)
]
case PacketType.PING:
return PingPacket()
return [PingPacket()]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
@@ -153,38 +157,48 @@ class MsgpackProtocol:
]
)
elif isinstance(packet, PingPacket):
pass
data = msgpack.packb(payload, use_bin_type=True)
payload.pop(-1)
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
return MsgpackProtocol._encode_varint(len(data)) + data
class JSONProtocol:
@staticmethod
def decode(input: bytes) -> Packet:
data = json.loads(input[:-1].decode("utf-8"))
packet_type = PacketType(data["type"])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return InvocationPacket(
header=data.get("header"),
invocation_id=data.get("invocationId"),
target=data["target"],
arguments=data.get("arguments"),
stream_ids=data.get("streamIds"),
)
case PacketType.COMPLETION:
return CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
case PacketType.PING:
return PingPacket()
raise ValueError(f"Unsupported packet type: {packet_type}")
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
packets = []
if len(packets_raw) > 1:
for packet_raw in packets_raw:
packets.extend(JSONProtocol.decode(packet_raw))
return packets
else:
data = json.loads(packets_raw[0])
packet_type = PacketType(data["type"])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return [
InvocationPacket(
header=data.get("header"),
invocation_id=data.get("invocationId"),
target=data["target"],
arguments=data.get("arguments"),
stream_ids=data.get("streamIds"),
)
]
case PacketType.COMPLETION:
return [
CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
]
case PacketType.PING:
return [PingPacket()]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes: