feat(spectator): support spectate solo player
This commit is contained in:
@@ -44,11 +44,11 @@ class Client:
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
|
||||
async def receive_packet(self) -> Packet:
|
||||
async def receive_packets(self) -> list[Packet]:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return PingPacket() # FIXME: Graceful empty message handling
|
||||
return [PingPacket()] # FIXME: Graceful empty message handling
|
||||
return self.procotol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
@@ -138,12 +138,13 @@ class Hub:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
packet = await client.receive_packet()
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except StopIteration:
|
||||
pass
|
||||
packets = await client.receive_packets()
|
||||
for packet in packets:
|
||||
if isinstance(packet, PingPacket):
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
print(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
|
||||
@@ -106,11 +106,12 @@ class MetadataHub(Hub):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
||||
if activity_dict is None:
|
||||
# idle
|
||||
return
|
||||
user_id = int(client.connection_id)
|
||||
activity = TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
activity = (
|
||||
TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
if activity_dict
|
||||
else None
|
||||
)
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.user_activity = activity
|
||||
@@ -119,7 +120,6 @@ class MetadataHub(Hub):
|
||||
user_activity=activity,
|
||||
)
|
||||
self.state[user_id] = store
|
||||
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import lzma
|
||||
import struct
|
||||
@@ -13,6 +14,7 @@ from app.dependencies.database import engine
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
||||
from app.models.signalr import serialize_to_list
|
||||
from app.models.spectator_hub import (
|
||||
APIUser,
|
||||
FrameDataBundle,
|
||||
@@ -106,7 +108,7 @@ def save_replay(
|
||||
for frame in frames:
|
||||
frame_strs.append(
|
||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state.value}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state}"
|
||||
)
|
||||
last_time = frame.time
|
||||
frame_strs.append("-12345|0|0|0")
|
||||
@@ -143,6 +145,20 @@ class SpectatorHub(Hub):
|
||||
super().__init__()
|
||||
self.state: dict[int, StoreClientState] = {}
|
||||
|
||||
@staticmethod
|
||||
def group_id(user_id: int) -> str:
|
||||
return f"watch:{user_id}"
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
tasks = [
|
||||
self.call_noblock(
|
||||
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.state is not None
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None:
|
||||
@@ -184,7 +200,12 @@ class SpectatorHub(Hub):
|
||||
),
|
||||
)
|
||||
self.state[user_id] = store
|
||||
await self.broadcast_call("UserBeganPlaying", user_id, state.model_dump())
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state),
|
||||
)
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
@@ -201,14 +222,14 @@ class SpectatorHub(Hub):
|
||||
score.score_info.total_score = frame_data.header.total_score
|
||||
score.score_info.mods = frame_data.header.mods
|
||||
score.replay_frames.extend(frame_data.frames)
|
||||
await self.broadcast_call(
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserSentFrames",
|
||||
user_id,
|
||||
frame_data.model_dump(),
|
||||
)
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
print(f"EndPlaySession -> {client.connection_id} {state.model_dump()!r}")
|
||||
user_id = int(client.connection_id)
|
||||
store = self.state.get(user_id)
|
||||
if not store:
|
||||
@@ -217,7 +238,13 @@ class SpectatorHub(Hub):
|
||||
if not score or not store.score_token:
|
||||
return
|
||||
|
||||
assert store.beatmap_status is not None
|
||||
|
||||
async def _save_replay():
|
||||
assert store.checksum is not None
|
||||
assert store.ruleset_id is not None
|
||||
assert store.state is not None
|
||||
assert store.score is not None
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session:
|
||||
start_time = time.time()
|
||||
@@ -271,8 +298,51 @@ class SpectatorHub(Hub):
|
||||
del self.state[user_id]
|
||||
if state.state == SpectatedUserState.Playing:
|
||||
state.state = SpectatedUserState.Quit
|
||||
await self.broadcast_call(
|
||||
"UserEndedPlaying",
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserFinishedPlaying",
|
||||
user_id,
|
||||
state.model_dump(),
|
||||
serialize_to_list(state) if state else None,
|
||||
)
|
||||
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
|
||||
user_id = int(client.connection_id)
|
||||
target_store = self.state.get(target_id)
|
||||
if target_store and target_store.state:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
serialize_to_list(target_store.state),
|
||||
)
|
||||
store = self.state.get(user_id)
|
||||
if store is None:
|
||||
store = StoreClientState(
|
||||
watched_user=set(),
|
||||
)
|
||||
store.watched_user.add(target_id)
|
||||
self.state[user_id] = store
|
||||
self.groups.setdefault(self.group_id(target_id), set()).add(client)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
username = (
|
||||
await session.exec(select(User.name).where(User.id == user_id))
|
||||
).first()
|
||||
if not username:
|
||||
return
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(
|
||||
target_client, "UserStartedWatching", [[user_id, username]]
|
||||
)
|
||||
|
||||
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
|
||||
user_id = int(client.connection_id)
|
||||
self.groups[self.group_id(target_id)].discard(client)
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.watched_user.discard(target_id)
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||
|
||||
@@ -60,7 +60,7 @@ PACKETS = {
|
||||
|
||||
class Protocol(TypingProtocol):
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> Packet: ...
|
||||
def decode(input: bytes) -> list[Packet]: ...
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes: ...
|
||||
@@ -93,7 +93,7 @@ class MsgpackProtocol:
|
||||
return result, pos
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> Packet:
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
length, offset = MsgpackProtocol._decode_varint(input)
|
||||
message_data = input[offset : offset + length]
|
||||
# FIXME: custom deserializer for APIMod
|
||||
@@ -106,23 +106,27 @@ class MsgpackProtocol:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return InvocationPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
target=unpacked[3],
|
||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||
)
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
target=unpacked[3],
|
||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
result_kind = unpacked[3]
|
||||
return CompletionPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
error=unpacked[4] if result_kind == 1 else None,
|
||||
result=unpacked[5] if result_kind == 3 else None,
|
||||
)
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
error=unpacked[4] if result_kind == 1 else None,
|
||||
result=unpacked[5] if result_kind == 3 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return PingPacket()
|
||||
return [PingPacket()]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
@@ -153,38 +157,48 @@ class MsgpackProtocol:
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
|
||||
data = msgpack.packb(payload, use_bin_type=True)
|
||||
payload.pop(-1)
|
||||
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||
|
||||
|
||||
class JSONProtocol:
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> Packet:
|
||||
data = json.loads(input[:-1].decode("utf-8"))
|
||||
packet_type = PacketType(data["type"])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return InvocationPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data.get("invocationId"),
|
||||
target=data["target"],
|
||||
arguments=data.get("arguments"),
|
||||
stream_ids=data.get("streamIds"),
|
||||
)
|
||||
case PacketType.COMPLETION:
|
||||
return CompletionPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data["invocationId"],
|
||||
error=data.get("error"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
case PacketType.PING:
|
||||
return PingPacket()
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||
packets = []
|
||||
if len(packets_raw) > 1:
|
||||
for packet_raw in packets_raw:
|
||||
packets.extend(JSONProtocol.decode(packet_raw))
|
||||
return packets
|
||||
else:
|
||||
data = json.loads(packets_raw[0])
|
||||
packet_type = PacketType(data["type"])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data.get("invocationId"),
|
||||
target=data["target"],
|
||||
arguments=data.get("arguments"),
|
||||
stream_ids=data.get("streamIds"),
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data["invocationId"],
|
||||
error=data.get("error"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user