feat(spectator): support spectate solo player
This commit is contained in:
@@ -28,8 +28,7 @@ class _UserActivity(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChoosingBeatmap(_UserActivity):
|
class ChoosingBeatmap(_UserActivity):
|
||||||
type: Literal["ChoosingBeatmap"] = "ChoosingBeatmap"
|
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
|
||||||
value: Literal[None] = None
|
|
||||||
|
|
||||||
|
|
||||||
class InGameValue(BaseModel):
|
class InGameValue(BaseModel):
|
||||||
@@ -44,19 +43,19 @@ class _InGame(_UserActivity):
|
|||||||
|
|
||||||
|
|
||||||
class InSoloGame(_InGame):
|
class InSoloGame(_InGame):
|
||||||
type: Literal["InSoloGame"] = "InSoloGame"
|
type: Literal["InSoloGame"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class InMultiplayerGame(_InGame):
|
class InMultiplayerGame(_InGame):
|
||||||
type: Literal["InMultiplayerGame"] = "InMultiplayerGame"
|
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class SpectatingMultiplayerGame(_InGame):
|
class SpectatingMultiplayerGame(_InGame):
|
||||||
type: Literal["SpectatingMultiplayerGame"] = "SpectatingMultiplayerGame"
|
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class InPlaylistGame(_InGame):
|
class InPlaylistGame(_InGame):
|
||||||
type: Literal["InPlaylistGame"] = "InPlaylistGame"
|
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class EditingBeatmapValue(BaseModel):
|
class EditingBeatmapValue(BaseModel):
|
||||||
@@ -65,16 +64,16 @@ class EditingBeatmapValue(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class EditingBeatmap(_UserActivity):
|
class EditingBeatmap(_UserActivity):
|
||||||
type: Literal["EditingBeatmap"] = "EditingBeatmap"
|
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
|
||||||
value: EditingBeatmapValue = Field(alias="$value")
|
value: EditingBeatmapValue = Field(alias="$value")
|
||||||
|
|
||||||
|
|
||||||
class TestingBeatmap(_UserActivity):
|
class TestingBeatmap(_UserActivity):
|
||||||
type: Literal["TestingBeatmap"] = "TestingBeatmap"
|
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class ModdingBeatmap(_UserActivity):
|
class ModdingBeatmap(_UserActivity):
|
||||||
type: Literal["ModdingBeatmap"] = "ModdingBeatmap"
|
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class WatchingReplayValue(BaseModel):
|
class WatchingReplayValue(BaseModel):
|
||||||
@@ -85,17 +84,16 @@ class WatchingReplayValue(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class WatchingReplay(_UserActivity):
|
class WatchingReplay(_UserActivity):
|
||||||
type: Literal["WatchingReplay"] = "WatchingReplay"
|
type: Literal["WatchingReplay"] = Field(alias="$dtype")
|
||||||
value: int | None = Field(alias="$value") # Replay ID
|
value: int | None = Field(alias="$value") # Replay ID
|
||||||
|
|
||||||
|
|
||||||
class SpectatingUser(WatchingReplay):
|
class SpectatingUser(WatchingReplay):
|
||||||
type: Literal["SpectatingUser"] = "SpectatingUser"
|
type: Literal["SpectatingUser"] = Field(alias="$dtype")
|
||||||
|
|
||||||
|
|
||||||
class SearchingForLobby(_UserActivity):
|
class SearchingForLobby(_UserActivity):
|
||||||
type: Literal["SearchingForLobby"] = "SearchingForLobby"
|
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
|
||||||
value: None = Field(alias="$value")
|
|
||||||
|
|
||||||
|
|
||||||
class InLobbyValue(BaseModel):
|
class InLobbyValue(BaseModel):
|
||||||
@@ -105,12 +103,10 @@ class InLobbyValue(BaseModel):
|
|||||||
|
|
||||||
class InLobby(_UserActivity):
|
class InLobby(_UserActivity):
|
||||||
type: Literal["InLobby"] = "InLobby"
|
type: Literal["InLobby"] = "InLobby"
|
||||||
value: None = Field(alias="$value")
|
|
||||||
|
|
||||||
|
|
||||||
class InDailyChallengeLobby(_UserActivity):
|
class InDailyChallengeLobby(_UserActivity):
|
||||||
type: Literal["InDailyChallengeLobby"] = "InDailyChallengeLobby"
|
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
|
||||||
value: None = Field(alias="$value")
|
|
||||||
|
|
||||||
|
|
||||||
UserActivity = (
|
UserActivity = (
|
||||||
|
|||||||
@@ -1,11 +1,42 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
import datetime
|
||||||
|
from typing import Any, get_origin
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
import msgpack
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
TypeAdapter,
|
||||||
|
model_serializer,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_to_list(value: BaseModel) -> list[Any]:
|
||||||
|
data = []
|
||||||
|
for field, info in value.__class__.model_fields.items():
|
||||||
|
v = getattr(value, field)
|
||||||
|
anno = get_origin(info.annotation)
|
||||||
|
if anno and issubclass(anno, BaseModel):
|
||||||
|
data.append(serialize_to_list(v))
|
||||||
|
elif anno and issubclass(anno, list):
|
||||||
|
data.append(
|
||||||
|
TypeAdapter(
|
||||||
|
info.annotation,
|
||||||
|
).dump_python(v)
|
||||||
|
)
|
||||||
|
elif isinstance(v, datetime.datetime):
|
||||||
|
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
|
||||||
|
else:
|
||||||
|
data.append(v)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class MessagePackArrayModel(BaseModel):
|
class MessagePackArrayModel(BaseModel):
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def unpack(cls, v: Any) -> Any:
|
def unpack(cls, v: Any) -> Any:
|
||||||
@@ -16,6 +47,10 @@ class MessagePackArrayModel(BaseModel):
|
|||||||
return dict(zip(fields, v))
|
return dict(zip(fields, v))
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@model_serializer
|
||||||
|
def serialize(self) -> list[Any]:
|
||||||
|
return serialize_to_list(self)
|
||||||
|
|
||||||
|
|
||||||
class Transport(BaseModel):
|
class Transport(BaseModel):
|
||||||
transport: str
|
transport: str
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from .score import (
|
|||||||
from .signalr import MessagePackArrayModel
|
from .signalr import MessagePackArrayModel
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
class APIMod(MessagePackArrayModel):
|
class APIMod(MessagePackArrayModel):
|
||||||
@@ -58,8 +58,6 @@ class ScoreProcessorStatistics(MessagePackArrayModel):
|
|||||||
|
|
||||||
|
|
||||||
class FrameHeader(MessagePackArrayModel):
|
class FrameHeader(MessagePackArrayModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
total_score: int
|
total_score: int
|
||||||
acc: float
|
acc: float
|
||||||
combo: int
|
combo: int
|
||||||
@@ -84,25 +82,21 @@ class FrameHeader(MessagePackArrayModel):
|
|||||||
return datetime.datetime.fromisoformat(v)
|
return datetime.datetime.fromisoformat(v)
|
||||||
raise ValueError(f"Cannot convert {type(v)} to datetime")
|
raise ValueError(f"Cannot convert {type(v)} to datetime")
|
||||||
|
|
||||||
@field_serializer("received_time")
|
|
||||||
def serialize_received_time(self, v: datetime.datetime) -> msgpack.ext.Timestamp:
|
|
||||||
return msgpack.ext.Timestamp.from_datetime(v)
|
|
||||||
|
|
||||||
|
# class ReplayButtonState(IntEnum):
|
||||||
class ReplayButtonState(IntEnum):
|
# NONE = 0
|
||||||
NONE = 0
|
# LEFT1 = 1
|
||||||
LEFT1 = 1
|
# RIGHT1 = 2
|
||||||
RIGHT1 = 2
|
# LEFT2 = 4
|
||||||
LEFT2 = 4
|
# RIGHT2 = 8
|
||||||
RIGHT2 = 8
|
# SMOKE = 16
|
||||||
SMOKE = 16
|
|
||||||
|
|
||||||
|
|
||||||
class LegacyReplayFrame(MessagePackArrayModel):
|
class LegacyReplayFrame(MessagePackArrayModel):
|
||||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||||
x: float | None = None
|
x: float | None = None
|
||||||
y: float | None = None
|
y: float | None = None
|
||||||
button_state: ReplayButtonState
|
button_state: int
|
||||||
|
|
||||||
|
|
||||||
class FrameDataBundle(MessagePackArrayModel):
|
class FrameDataBundle(MessagePackArrayModel):
|
||||||
@@ -135,10 +129,10 @@ class StoreScore(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class StoreClientState(BaseModel):
|
class StoreClientState(BaseModel):
|
||||||
state: SpectatorState | None
|
state: SpectatorState | None = None
|
||||||
beatmap_status: BeatmapRankStatus
|
beatmap_status: BeatmapRankStatus | None = None
|
||||||
checksum: str
|
checksum: str | None = None
|
||||||
ruleset_id: int
|
ruleset_id: int | None = None
|
||||||
score_token: int
|
score_token: int | None = None
|
||||||
watched_user: set[int]
|
watched_user: set[int] = Field(default_factory=set)
|
||||||
score: StoreScore
|
score: StoreScore | None = None
|
||||||
|
|||||||
@@ -44,11 +44,11 @@ class Client:
|
|||||||
async def send_packet(self, packet: Packet):
|
async def send_packet(self, packet: Packet):
|
||||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||||
|
|
||||||
async def receive_packet(self) -> Packet:
|
async def receive_packets(self) -> list[Packet]:
|
||||||
message = await self.connection.receive()
|
message = await self.connection.receive()
|
||||||
d = message.get("bytes") or message.get("text", "").encode()
|
d = message.get("bytes") or message.get("text", "").encode()
|
||||||
if not d:
|
if not d:
|
||||||
return PingPacket() # FIXME: Graceful empty message handling
|
return [PingPacket()] # FIXME: Graceful empty message handling
|
||||||
return self.procotol.decode(d)
|
return self.procotol.decode(d)
|
||||||
|
|
||||||
async def _ping(self):
|
async def _ping(self):
|
||||||
@@ -138,12 +138,13 @@ class Hub:
|
|||||||
jump = False
|
jump = False
|
||||||
while not jump:
|
while not jump:
|
||||||
try:
|
try:
|
||||||
packet = await client.receive_packet()
|
packets = await client.receive_packets()
|
||||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
for packet in packets:
|
||||||
self.tasks.add(task)
|
if isinstance(packet, PingPacket):
|
||||||
task.add_done_callback(self.tasks.discard)
|
continue
|
||||||
except StopIteration:
|
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||||
pass
|
self.tasks.add(task)
|
||||||
|
task.add_done_callback(self.tasks.discard)
|
||||||
except WebSocketDisconnect as e:
|
except WebSocketDisconnect as e:
|
||||||
print(
|
print(
|
||||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||||
|
|||||||
@@ -106,11 +106,12 @@ class MetadataHub(Hub):
|
|||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
||||||
if activity_dict is None:
|
|
||||||
# idle
|
|
||||||
return
|
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
activity = TypeAdapter(UserActivity).validate_python(activity_dict)
|
activity = (
|
||||||
|
TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||||
|
if activity_dict
|
||||||
|
else None
|
||||||
|
)
|
||||||
store = self.state.get(user_id)
|
store = self.state.get(user_id)
|
||||||
if store:
|
if store:
|
||||||
store.user_activity = activity
|
store.user_activity = activity
|
||||||
@@ -119,7 +120,6 @@ class MetadataHub(Hub):
|
|||||||
user_activity=activity,
|
user_activity=activity,
|
||||||
)
|
)
|
||||||
self.state[user_id] = store
|
self.state[user_id] = store
|
||||||
|
|
||||||
tasks = self.broadcast_tasks(user_id, store)
|
tasks = self.broadcast_tasks(user_id, store)
|
||||||
tasks.add(
|
tasks.add(
|
||||||
self.call_noblock(
|
self.call_noblock(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import lzma
|
import lzma
|
||||||
import struct
|
import struct
|
||||||
@@ -13,6 +14,7 @@ from app.dependencies.database import engine
|
|||||||
from app.models.beatmap import BeatmapRankStatus
|
from app.models.beatmap import BeatmapRankStatus
|
||||||
from app.models.mods import mods_to_int
|
from app.models.mods import mods_to_int
|
||||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
||||||
|
from app.models.signalr import serialize_to_list
|
||||||
from app.models.spectator_hub import (
|
from app.models.spectator_hub import (
|
||||||
APIUser,
|
APIUser,
|
||||||
FrameDataBundle,
|
FrameDataBundle,
|
||||||
@@ -106,7 +108,7 @@ def save_replay(
|
|||||||
for frame in frames:
|
for frame in frames:
|
||||||
frame_strs.append(
|
frame_strs.append(
|
||||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
f"{frame.time - last_time}|{frame.x or 0.0}"
|
||||||
f"|{frame.y or 0.0}|{frame.button_state.value}"
|
f"|{frame.y or 0.0}|{frame.button_state}"
|
||||||
)
|
)
|
||||||
last_time = frame.time
|
last_time = frame.time
|
||||||
frame_strs.append("-12345|0|0|0")
|
frame_strs.append("-12345|0|0|0")
|
||||||
@@ -143,6 +145,20 @@ class SpectatorHub(Hub):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.state: dict[int, StoreClientState] = {}
|
self.state: dict[int, StoreClientState] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def group_id(user_id: int) -> str:
|
||||||
|
return f"watch:{user_id}"
|
||||||
|
|
||||||
|
async def on_client_connect(self, client: Client) -> None:
|
||||||
|
tasks = [
|
||||||
|
self.call_noblock(
|
||||||
|
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
||||||
|
)
|
||||||
|
for user_id, store in self.state.items()
|
||||||
|
if store.state is not None
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def BeginPlaySession(
|
async def BeginPlaySession(
|
||||||
self, client: Client, score_token: int, state: SpectatorState
|
self, client: Client, score_token: int, state: SpectatorState
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -184,7 +200,12 @@ class SpectatorHub(Hub):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.state[user_id] = store
|
self.state[user_id] = store
|
||||||
await self.broadcast_call("UserBeganPlaying", user_id, state.model_dump())
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(user_id),
|
||||||
|
"UserBeganPlaying",
|
||||||
|
user_id,
|
||||||
|
serialize_to_list(state),
|
||||||
|
)
|
||||||
|
|
||||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
@@ -201,14 +222,14 @@ class SpectatorHub(Hub):
|
|||||||
score.score_info.total_score = frame_data.header.total_score
|
score.score_info.total_score = frame_data.header.total_score
|
||||||
score.score_info.mods = frame_data.header.mods
|
score.score_info.mods = frame_data.header.mods
|
||||||
score.replay_frames.extend(frame_data.frames)
|
score.replay_frames.extend(frame_data.frames)
|
||||||
await self.broadcast_call(
|
await self.broadcast_group_call(
|
||||||
|
self.group_id(user_id),
|
||||||
"UserSentFrames",
|
"UserSentFrames",
|
||||||
user_id,
|
user_id,
|
||||||
frame_data.model_dump(),
|
frame_data.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||||
print(f"EndPlaySession -> {client.connection_id} {state.model_dump()!r}")
|
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
store = self.state.get(user_id)
|
store = self.state.get(user_id)
|
||||||
if not store:
|
if not store:
|
||||||
@@ -217,7 +238,13 @@ class SpectatorHub(Hub):
|
|||||||
if not score or not store.score_token:
|
if not score or not store.score_token:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
assert store.beatmap_status is not None
|
||||||
|
|
||||||
async def _save_replay():
|
async def _save_replay():
|
||||||
|
assert store.checksum is not None
|
||||||
|
assert store.ruleset_id is not None
|
||||||
|
assert store.state is not None
|
||||||
|
assert store.score is not None
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
async with session:
|
async with session:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -271,8 +298,51 @@ class SpectatorHub(Hub):
|
|||||||
del self.state[user_id]
|
del self.state[user_id]
|
||||||
if state.state == SpectatedUserState.Playing:
|
if state.state == SpectatedUserState.Playing:
|
||||||
state.state = SpectatedUserState.Quit
|
state.state = SpectatedUserState.Quit
|
||||||
await self.broadcast_call(
|
await self.broadcast_group_call(
|
||||||
"UserEndedPlaying",
|
self.group_id(user_id),
|
||||||
|
"UserFinishedPlaying",
|
||||||
user_id,
|
user_id,
|
||||||
state.model_dump(),
|
serialize_to_list(state) if state else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||||
|
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
|
||||||
|
user_id = int(client.connection_id)
|
||||||
|
target_store = self.state.get(target_id)
|
||||||
|
if target_store and target_store.state:
|
||||||
|
await self.call_noblock(
|
||||||
|
client,
|
||||||
|
"UserBeganPlaying",
|
||||||
|
target_id,
|
||||||
|
serialize_to_list(target_store.state),
|
||||||
|
)
|
||||||
|
store = self.state.get(user_id)
|
||||||
|
if store is None:
|
||||||
|
store = StoreClientState(
|
||||||
|
watched_user=set(),
|
||||||
|
)
|
||||||
|
store.watched_user.add(target_id)
|
||||||
|
self.state[user_id] = store
|
||||||
|
self.groups.setdefault(self.group_id(target_id), set()).add(client)
|
||||||
|
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
async with session.begin():
|
||||||
|
username = (
|
||||||
|
await session.exec(select(User.name).where(User.id == user_id))
|
||||||
|
).first()
|
||||||
|
if not username:
|
||||||
|
return
|
||||||
|
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||||
|
await self.call_noblock(
|
||||||
|
target_client, "UserStartedWatching", [[user_id, username]]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||||
|
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
|
||||||
|
user_id = int(client.connection_id)
|
||||||
|
self.groups[self.group_id(target_id)].discard(client)
|
||||||
|
store = self.state.get(user_id)
|
||||||
|
if store:
|
||||||
|
store.watched_user.discard(target_id)
|
||||||
|
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||||
|
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ PACKETS = {
|
|||||||
|
|
||||||
class Protocol(TypingProtocol):
|
class Protocol(TypingProtocol):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode(input: bytes) -> Packet: ...
|
def decode(input: bytes) -> list[Packet]: ...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode(packet: Packet) -> bytes: ...
|
def encode(packet: Packet) -> bytes: ...
|
||||||
@@ -93,7 +93,7 @@ class MsgpackProtocol:
|
|||||||
return result, pos
|
return result, pos
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode(input: bytes) -> Packet:
|
def decode(input: bytes) -> list[Packet]:
|
||||||
length, offset = MsgpackProtocol._decode_varint(input)
|
length, offset = MsgpackProtocol._decode_varint(input)
|
||||||
message_data = input[offset : offset + length]
|
message_data = input[offset : offset + length]
|
||||||
# FIXME: custom deserializer for APIMod
|
# FIXME: custom deserializer for APIMod
|
||||||
@@ -106,23 +106,27 @@ class MsgpackProtocol:
|
|||||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||||
match packet_type:
|
match packet_type:
|
||||||
case PacketType.INVOCATION:
|
case PacketType.INVOCATION:
|
||||||
return InvocationPacket(
|
return [
|
||||||
header=unpacked[1],
|
InvocationPacket(
|
||||||
invocation_id=unpacked[2],
|
header=unpacked[1],
|
||||||
target=unpacked[3],
|
invocation_id=unpacked[2],
|
||||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
target=unpacked[3],
|
||||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||||
)
|
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||||
|
)
|
||||||
|
]
|
||||||
case PacketType.COMPLETION:
|
case PacketType.COMPLETION:
|
||||||
result_kind = unpacked[3]
|
result_kind = unpacked[3]
|
||||||
return CompletionPacket(
|
return [
|
||||||
header=unpacked[1],
|
CompletionPacket(
|
||||||
invocation_id=unpacked[2],
|
header=unpacked[1],
|
||||||
error=unpacked[4] if result_kind == 1 else None,
|
invocation_id=unpacked[2],
|
||||||
result=unpacked[5] if result_kind == 3 else None,
|
error=unpacked[4] if result_kind == 1 else None,
|
||||||
)
|
result=unpacked[5] if result_kind == 3 else None,
|
||||||
|
)
|
||||||
|
]
|
||||||
case PacketType.PING:
|
case PacketType.PING:
|
||||||
return PingPacket()
|
return [PingPacket()]
|
||||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -153,38 +157,48 @@ class MsgpackProtocol:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
elif isinstance(packet, PingPacket):
|
elif isinstance(packet, PingPacket):
|
||||||
pass
|
payload.pop(-1)
|
||||||
|
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||||
data = msgpack.packb(payload, use_bin_type=True)
|
|
||||||
return MsgpackProtocol._encode_varint(len(data)) + data
|
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||||
|
|
||||||
|
|
||||||
class JSONProtocol:
|
class JSONProtocol:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode(input: bytes) -> Packet:
|
def decode(input: bytes) -> list[Packet]:
|
||||||
data = json.loads(input[:-1].decode("utf-8"))
|
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||||
packet_type = PacketType(data["type"])
|
packets = []
|
||||||
if packet_type not in PACKETS:
|
if len(packets_raw) > 1:
|
||||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
for packet_raw in packets_raw:
|
||||||
match packet_type:
|
packets.extend(JSONProtocol.decode(packet_raw))
|
||||||
case PacketType.INVOCATION:
|
return packets
|
||||||
return InvocationPacket(
|
else:
|
||||||
header=data.get("header"),
|
data = json.loads(packets_raw[0])
|
||||||
invocation_id=data.get("invocationId"),
|
packet_type = PacketType(data["type"])
|
||||||
target=data["target"],
|
if packet_type not in PACKETS:
|
||||||
arguments=data.get("arguments"),
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||||
stream_ids=data.get("streamIds"),
|
match packet_type:
|
||||||
)
|
case PacketType.INVOCATION:
|
||||||
case PacketType.COMPLETION:
|
return [
|
||||||
return CompletionPacket(
|
InvocationPacket(
|
||||||
header=data.get("header"),
|
header=data.get("header"),
|
||||||
invocation_id=data["invocationId"],
|
invocation_id=data.get("invocationId"),
|
||||||
error=data.get("error"),
|
target=data["target"],
|
||||||
result=data.get("result"),
|
arguments=data.get("arguments"),
|
||||||
)
|
stream_ids=data.get("streamIds"),
|
||||||
case PacketType.PING:
|
)
|
||||||
return PingPacket()
|
]
|
||||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
case PacketType.COMPLETION:
|
||||||
|
return [
|
||||||
|
CompletionPacket(
|
||||||
|
header=data.get("header"),
|
||||||
|
invocation_id=data["invocationId"],
|
||||||
|
error=data.get("error"),
|
||||||
|
result=data.get("result"),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
case PacketType.PING:
|
||||||
|
return [PingPacket()]
|
||||||
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode(packet: Packet) -> bytes:
|
def encode(packet: Packet) -> bytes:
|
||||||
|
|||||||
Reference in New Issue
Block a user