feat(spectator): support spectate solo player
This commit is contained in:
@@ -28,8 +28,7 @@ class _UserActivity(BaseModel):
|
||||
|
||||
|
||||
class ChoosingBeatmap(_UserActivity):
|
||||
type: Literal["ChoosingBeatmap"] = "ChoosingBeatmap"
|
||||
value: Literal[None] = None
|
||||
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InGameValue(BaseModel):
|
||||
@@ -44,19 +43,19 @@ class _InGame(_UserActivity):
|
||||
|
||||
|
||||
class InSoloGame(_InGame):
|
||||
type: Literal["InSoloGame"] = "InSoloGame"
|
||||
type: Literal["InSoloGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InMultiplayerGame(_InGame):
|
||||
type: Literal["InMultiplayerGame"] = "InMultiplayerGame"
|
||||
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class SpectatingMultiplayerGame(_InGame):
|
||||
type: Literal["SpectatingMultiplayerGame"] = "SpectatingMultiplayerGame"
|
||||
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InPlaylistGame(_InGame):
|
||||
type: Literal["InPlaylistGame"] = "InPlaylistGame"
|
||||
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class EditingBeatmapValue(BaseModel):
|
||||
@@ -65,16 +64,16 @@ class EditingBeatmapValue(BaseModel):
|
||||
|
||||
|
||||
class EditingBeatmap(_UserActivity):
|
||||
type: Literal["EditingBeatmap"] = "EditingBeatmap"
|
||||
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
|
||||
value: EditingBeatmapValue = Field(alias="$value")
|
||||
|
||||
|
||||
class TestingBeatmap(_UserActivity):
|
||||
type: Literal["TestingBeatmap"] = "TestingBeatmap"
|
||||
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class ModdingBeatmap(_UserActivity):
|
||||
type: Literal["ModdingBeatmap"] = "ModdingBeatmap"
|
||||
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class WatchingReplayValue(BaseModel):
|
||||
@@ -85,17 +84,16 @@ class WatchingReplayValue(BaseModel):
|
||||
|
||||
|
||||
class WatchingReplay(_UserActivity):
|
||||
type: Literal["WatchingReplay"] = "WatchingReplay"
|
||||
type: Literal["WatchingReplay"] = Field(alias="$dtype")
|
||||
value: int | None = Field(alias="$value") # Replay ID
|
||||
|
||||
|
||||
class SpectatingUser(WatchingReplay):
|
||||
type: Literal["SpectatingUser"] = "SpectatingUser"
|
||||
type: Literal["SpectatingUser"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class SearchingForLobby(_UserActivity):
|
||||
type: Literal["SearchingForLobby"] = "SearchingForLobby"
|
||||
value: None = Field(alias="$value")
|
||||
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InLobbyValue(BaseModel):
|
||||
@@ -105,12 +103,10 @@ class InLobbyValue(BaseModel):
|
||||
|
||||
class InLobby(_UserActivity):
|
||||
type: Literal["InLobby"] = "InLobby"
|
||||
value: None = Field(alias="$value")
|
||||
|
||||
|
||||
class InDailyChallengeLobby(_UserActivity):
|
||||
type: Literal["InDailyChallengeLobby"] = "InDailyChallengeLobby"
|
||||
value: None = Field(alias="$value")
|
||||
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
UserActivity = (
|
||||
|
||||
@@ -1,11 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import datetime
|
||||
from typing import Any, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
import msgpack
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
def serialize_to_list(value: BaseModel) -> list[Any]:
|
||||
data = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
v = getattr(value, field)
|
||||
anno = get_origin(info.annotation)
|
||||
if anno and issubclass(anno, BaseModel):
|
||||
data.append(serialize_to_list(v))
|
||||
elif anno and issubclass(anno, list):
|
||||
data.append(
|
||||
TypeAdapter(
|
||||
info.annotation,
|
||||
).dump_python(v)
|
||||
)
|
||||
elif isinstance(v, datetime.datetime):
|
||||
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
|
||||
else:
|
||||
data.append(v)
|
||||
return data
|
||||
|
||||
|
||||
class MessagePackArrayModel(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def unpack(cls, v: Any) -> Any:
|
||||
@@ -16,6 +47,10 @@ class MessagePackArrayModel(BaseModel):
|
||||
return dict(zip(fields, v))
|
||||
return v
|
||||
|
||||
@model_serializer
|
||||
def serialize(self) -> list[Any]:
|
||||
return serialize_to_list(self)
|
||||
|
||||
|
||||
class Transport(BaseModel):
|
||||
transport: str
|
||||
|
||||
@@ -12,7 +12,7 @@ from .score import (
|
||||
from .signalr import MessagePackArrayModel
|
||||
|
||||
import msgpack
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class APIMod(MessagePackArrayModel):
|
||||
@@ -58,8 +58,6 @@ class ScoreProcessorStatistics(MessagePackArrayModel):
|
||||
|
||||
|
||||
class FrameHeader(MessagePackArrayModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
total_score: int
|
||||
acc: float
|
||||
combo: int
|
||||
@@ -84,25 +82,21 @@ class FrameHeader(MessagePackArrayModel):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
raise ValueError(f"Cannot convert {type(v)} to datetime")
|
||||
|
||||
@field_serializer("received_time")
|
||||
def serialize_received_time(self, v: datetime.datetime) -> msgpack.ext.Timestamp:
|
||||
return msgpack.ext.Timestamp.from_datetime(v)
|
||||
|
||||
|
||||
class ReplayButtonState(IntEnum):
|
||||
NONE = 0
|
||||
LEFT1 = 1
|
||||
RIGHT1 = 2
|
||||
LEFT2 = 4
|
||||
RIGHT2 = 8
|
||||
SMOKE = 16
|
||||
# class ReplayButtonState(IntEnum):
|
||||
# NONE = 0
|
||||
# LEFT1 = 1
|
||||
# RIGHT1 = 2
|
||||
# LEFT2 = 4
|
||||
# RIGHT2 = 8
|
||||
# SMOKE = 16
|
||||
|
||||
|
||||
class LegacyReplayFrame(MessagePackArrayModel):
|
||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||
x: float | None = None
|
||||
y: float | None = None
|
||||
button_state: ReplayButtonState
|
||||
button_state: int
|
||||
|
||||
|
||||
class FrameDataBundle(MessagePackArrayModel):
|
||||
@@ -135,10 +129,10 @@ class StoreScore(BaseModel):
|
||||
|
||||
|
||||
class StoreClientState(BaseModel):
|
||||
state: SpectatorState | None
|
||||
beatmap_status: BeatmapRankStatus
|
||||
checksum: str
|
||||
ruleset_id: int
|
||||
score_token: int
|
||||
watched_user: set[int]
|
||||
score: StoreScore
|
||||
state: SpectatorState | None = None
|
||||
beatmap_status: BeatmapRankStatus | None = None
|
||||
checksum: str | None = None
|
||||
ruleset_id: int | None = None
|
||||
score_token: int | None = None
|
||||
watched_user: set[int] = Field(default_factory=set)
|
||||
score: StoreScore | None = None
|
||||
|
||||
@@ -44,11 +44,11 @@ class Client:
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
|
||||
async def receive_packet(self) -> Packet:
|
||||
async def receive_packets(self) -> list[Packet]:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return PingPacket() # FIXME: Graceful empty message handling
|
||||
return [PingPacket()] # FIXME: Graceful empty message handling
|
||||
return self.procotol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
@@ -138,12 +138,13 @@ class Hub:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
packet = await client.receive_packet()
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except StopIteration:
|
||||
pass
|
||||
packets = await client.receive_packets()
|
||||
for packet in packets:
|
||||
if isinstance(packet, PingPacket):
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
print(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
|
||||
@@ -106,11 +106,12 @@ class MetadataHub(Hub):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
||||
if activity_dict is None:
|
||||
# idle
|
||||
return
|
||||
user_id = int(client.connection_id)
|
||||
activity = TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
activity = (
|
||||
TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
if activity_dict
|
||||
else None
|
||||
)
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.user_activity = activity
|
||||
@@ -119,7 +120,6 @@ class MetadataHub(Hub):
|
||||
user_activity=activity,
|
||||
)
|
||||
self.state[user_id] = store
|
||||
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import lzma
|
||||
import struct
|
||||
@@ -13,6 +14,7 @@ from app.dependencies.database import engine
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
||||
from app.models.signalr import serialize_to_list
|
||||
from app.models.spectator_hub import (
|
||||
APIUser,
|
||||
FrameDataBundle,
|
||||
@@ -106,7 +108,7 @@ def save_replay(
|
||||
for frame in frames:
|
||||
frame_strs.append(
|
||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state.value}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state}"
|
||||
)
|
||||
last_time = frame.time
|
||||
frame_strs.append("-12345|0|0|0")
|
||||
@@ -143,6 +145,20 @@ class SpectatorHub(Hub):
|
||||
super().__init__()
|
||||
self.state: dict[int, StoreClientState] = {}
|
||||
|
||||
@staticmethod
|
||||
def group_id(user_id: int) -> str:
|
||||
return f"watch:{user_id}"
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
tasks = [
|
||||
self.call_noblock(
|
||||
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.state is not None
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None:
|
||||
@@ -184,7 +200,12 @@ class SpectatorHub(Hub):
|
||||
),
|
||||
)
|
||||
self.state[user_id] = store
|
||||
await self.broadcast_call("UserBeganPlaying", user_id, state.model_dump())
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state),
|
||||
)
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
@@ -201,14 +222,14 @@ class SpectatorHub(Hub):
|
||||
score.score_info.total_score = frame_data.header.total_score
|
||||
score.score_info.mods = frame_data.header.mods
|
||||
score.replay_frames.extend(frame_data.frames)
|
||||
await self.broadcast_call(
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserSentFrames",
|
||||
user_id,
|
||||
frame_data.model_dump(),
|
||||
)
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
print(f"EndPlaySession -> {client.connection_id} {state.model_dump()!r}")
|
||||
user_id = int(client.connection_id)
|
||||
store = self.state.get(user_id)
|
||||
if not store:
|
||||
@@ -217,7 +238,13 @@ class SpectatorHub(Hub):
|
||||
if not score or not store.score_token:
|
||||
return
|
||||
|
||||
assert store.beatmap_status is not None
|
||||
|
||||
async def _save_replay():
|
||||
assert store.checksum is not None
|
||||
assert store.ruleset_id is not None
|
||||
assert store.state is not None
|
||||
assert store.score is not None
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session:
|
||||
start_time = time.time()
|
||||
@@ -271,8 +298,51 @@ class SpectatorHub(Hub):
|
||||
del self.state[user_id]
|
||||
if state.state == SpectatedUserState.Playing:
|
||||
state.state = SpectatedUserState.Quit
|
||||
await self.broadcast_call(
|
||||
"UserEndedPlaying",
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserFinishedPlaying",
|
||||
user_id,
|
||||
state.model_dump(),
|
||||
serialize_to_list(state) if state else None,
|
||||
)
|
||||
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
|
||||
user_id = int(client.connection_id)
|
||||
target_store = self.state.get(target_id)
|
||||
if target_store and target_store.state:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
serialize_to_list(target_store.state),
|
||||
)
|
||||
store = self.state.get(user_id)
|
||||
if store is None:
|
||||
store = StoreClientState(
|
||||
watched_user=set(),
|
||||
)
|
||||
store.watched_user.add(target_id)
|
||||
self.state[user_id] = store
|
||||
self.groups.setdefault(self.group_id(target_id), set()).add(client)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
username = (
|
||||
await session.exec(select(User.name).where(User.id == user_id))
|
||||
).first()
|
||||
if not username:
|
||||
return
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(
|
||||
target_client, "UserStartedWatching", [[user_id, username]]
|
||||
)
|
||||
|
||||
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
|
||||
user_id = int(client.connection_id)
|
||||
self.groups[self.group_id(target_id)].discard(client)
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.watched_user.discard(target_id)
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||
|
||||
@@ -60,7 +60,7 @@ PACKETS = {
|
||||
|
||||
class Protocol(TypingProtocol):
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> Packet: ...
|
||||
def decode(input: bytes) -> list[Packet]: ...
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes: ...
|
||||
@@ -93,7 +93,7 @@ class MsgpackProtocol:
|
||||
return result, pos
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> Packet:
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
length, offset = MsgpackProtocol._decode_varint(input)
|
||||
message_data = input[offset : offset + length]
|
||||
# FIXME: custom deserializer for APIMod
|
||||
@@ -106,23 +106,27 @@ class MsgpackProtocol:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return InvocationPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
target=unpacked[3],
|
||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||
)
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
target=unpacked[3],
|
||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
result_kind = unpacked[3]
|
||||
return CompletionPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
error=unpacked[4] if result_kind == 1 else None,
|
||||
result=unpacked[5] if result_kind == 3 else None,
|
||||
)
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
error=unpacked[4] if result_kind == 1 else None,
|
||||
result=unpacked[5] if result_kind == 3 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return PingPacket()
|
||||
return [PingPacket()]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
@@ -153,38 +157,48 @@ class MsgpackProtocol:
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
|
||||
data = msgpack.packb(payload, use_bin_type=True)
|
||||
payload.pop(-1)
|
||||
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||
|
||||
|
||||
class JSONProtocol:
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> Packet:
|
||||
data = json.loads(input[:-1].decode("utf-8"))
|
||||
packet_type = PacketType(data["type"])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return InvocationPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data.get("invocationId"),
|
||||
target=data["target"],
|
||||
arguments=data.get("arguments"),
|
||||
stream_ids=data.get("streamIds"),
|
||||
)
|
||||
case PacketType.COMPLETION:
|
||||
return CompletionPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data["invocationId"],
|
||||
error=data.get("error"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
case PacketType.PING:
|
||||
return PingPacket()
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||
packets = []
|
||||
if len(packets_raw) > 1:
|
||||
for packet_raw in packets_raw:
|
||||
packets.extend(JSONProtocol.decode(packet_raw))
|
||||
return packets
|
||||
else:
|
||||
data = json.loads(packets_raw[0])
|
||||
packet_type = PacketType(data["type"])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data.get("invocationId"),
|
||||
target=data["target"],
|
||||
arguments=data.get("arguments"),
|
||||
stream_ids=data.get("streamIds"),
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data["invocationId"],
|
||||
error=data.get("error"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user