274 lines
8.4 KiB
Python
274 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from enum import IntEnum
|
|
import json
|
|
from typing import (
|
|
Any,
|
|
Protocol as TypingProtocol,
|
|
)
|
|
|
|
import msgpack_lazer_api as m
|
|
|
|
SEP = b"\x1e"
|
|
|
|
|
|
class PacketType(IntEnum):
    """Integer tag identifying the kind of a packet.

    The numeric values are what the codecs in this module read from and
    write to the wire, so they must not be renumbered.
    """

    INVOCATION = 1
    STREAM_ITEM = 2
    COMPLETION = 3
    STREAM_INVOCATION = 4
    CANCEL_INVOCATION = 5
    PING = 6
    CLOSE = 7
|
|
|
|
|
|
@dataclass(kw_only=True)
|
|
class Packet:
|
|
type: PacketType
|
|
header: dict[str, Any] | None = None
|
|
|
|
|
|
@dataclass(kw_only=True)
class InvocationPacket(Packet):
    """Packet asking the receiver to call ``target`` with ``arguments``."""

    # Fixed discriminator for this packet kind.
    type: PacketType = PacketType.INVOCATION
    # Required keyword, but may be None (no identifier attached).
    invocation_id: str | None
    # Name of the method being invoked.
    target: str
    # Positional arguments for the call; None when none were supplied.
    arguments: list[Any] | None = None
    # Identifiers of associated streams; None when there are none.
    stream_ids: list[str] | None = None
|
|
|
|
|
|
@dataclass(kw_only=True)
class CompletionPacket(Packet):
    """Packet reporting the outcome of the invocation ``invocation_id``."""

    # Fixed discriminator for this packet kind.
    type: PacketType = PacketType.COMPLETION
    # Identifier of the invocation this completion answers.
    invocation_id: str
    # Returned value; required keyword, pass None when there is no value.
    result: Any
    # Error message; None when no error is reported.
    error: str | None = None
|
|
|
|
|
|
@dataclass(kw_only=True)
class PingPacket(Packet):
    """Packet with no fields of its own beyond the ``PING`` type tag."""

    # Fixed discriminator for this packet kind.
    type: PacketType = PacketType.PING
|
|
|
|
|
|
@dataclass(kw_only=True)
class ClosePacket(Packet):
    """Packet announcing that the connection is being closed."""

    # Fixed discriminator for this packet kind.
    type: PacketType = PacketType.CLOSE
    # Error message accompanying the close; None when there is none.
    error: str | None = None
    # Flag sent alongside the close; defaults to False.
    allow_reconnect: bool = False
|
|
|
|
|
|
# Packet types the codecs know how to build, keyed by their wire tag.
# STREAM_ITEM, STREAM_INVOCATION and CANCEL_INVOCATION are intentionally
# absent: the decoders raise ValueError for any type not listed here.
PACKETS: dict[PacketType, type[Packet]] = {
    PacketType.INVOCATION: InvocationPacket,
    PacketType.COMPLETION: CompletionPacket,
    PacketType.PING: PingPacket,
    PacketType.CLOSE: ClosePacket,
}
|
|
|
|
|
|
class Protocol(TypingProtocol):
    """Structural interface of a packet codec.

    Any class exposing static ``decode`` and ``encode`` callables with
    these signatures satisfies the interface.
    """

    @staticmethod
    def decode(input: bytes) -> list[Packet]:
        """Turn a raw payload into the list of packets it contains."""
        ...

    @staticmethod
    def encode(packet: Packet) -> bytes:
        """Turn one packet into its wire bytes."""
        ...
|
|
|
|
|
|
class MsgpackProtocol:
|
|
@staticmethod
|
|
def _encode_varint(value: int) -> bytes:
|
|
result = []
|
|
while value >= 0x80:
|
|
result.append((value & 0x7F) | 0x80)
|
|
value >>= 7
|
|
result.append(value & 0x7F)
|
|
return bytes(result)
|
|
|
|
@staticmethod
|
|
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
|
result = 0
|
|
shift = 0
|
|
pos = offset
|
|
|
|
while pos < len(data):
|
|
byte = data[pos]
|
|
result |= (byte & 0x7F) << shift
|
|
pos += 1
|
|
if (byte & 0x80) == 0:
|
|
break
|
|
shift += 7
|
|
|
|
return result, pos
|
|
|
|
@staticmethod
|
|
def decode(input: bytes) -> list[Packet]:
|
|
length, offset = MsgpackProtocol._decode_varint(input)
|
|
message_data = input[offset : offset + length]
|
|
unpacked = m.decode(message_data)
|
|
packet_type = PacketType(unpacked[0])
|
|
if packet_type not in PACKETS:
|
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
|
match packet_type:
|
|
case PacketType.INVOCATION:
|
|
return [
|
|
InvocationPacket(
|
|
header=unpacked[1],
|
|
invocation_id=unpacked[2],
|
|
target=unpacked[3],
|
|
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
|
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
|
)
|
|
]
|
|
case PacketType.COMPLETION:
|
|
result_kind = unpacked[3]
|
|
return [
|
|
CompletionPacket(
|
|
header=unpacked[1],
|
|
invocation_id=unpacked[2],
|
|
error=unpacked[4] if result_kind == 1 else None,
|
|
result=unpacked[5] if result_kind == 3 else None,
|
|
)
|
|
]
|
|
case PacketType.PING:
|
|
return [PingPacket()]
|
|
case PacketType.CLOSE:
|
|
return [
|
|
ClosePacket(
|
|
error=unpacked[1],
|
|
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
|
|
)
|
|
]
|
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
|
|
|
@staticmethod
|
|
def encode(packet: Packet) -> bytes:
|
|
payload = [packet.type.value, packet.header or {}]
|
|
if isinstance(packet, InvocationPacket):
|
|
payload.extend(
|
|
[
|
|
packet.invocation_id,
|
|
packet.target,
|
|
]
|
|
)
|
|
if packet.arguments is not None:
|
|
payload.append(packet.arguments)
|
|
if packet.stream_ids is not None:
|
|
payload.append(packet.stream_ids)
|
|
elif isinstance(packet, CompletionPacket):
|
|
result_kind = 2
|
|
if packet.error:
|
|
result_kind = 1
|
|
elif packet.result is None:
|
|
result_kind = 3
|
|
payload.extend(
|
|
[
|
|
packet.invocation_id,
|
|
result_kind,
|
|
packet.error or packet.result or None,
|
|
]
|
|
)
|
|
elif isinstance(packet, ClosePacket):
|
|
payload.extend(
|
|
[
|
|
packet.error or "",
|
|
packet.allow_reconnect,
|
|
]
|
|
)
|
|
elif isinstance(packet, PingPacket):
|
|
payload.pop(-1)
|
|
data = m.encode(payload)
|
|
return MsgpackProtocol._encode_varint(len(data)) + data
|
|
|
|
|
|
class JSONProtocol:
|
|
@staticmethod
|
|
def decode(input: bytes) -> list[Packet]:
|
|
packets_raw = input.removesuffix(SEP).split(SEP)
|
|
packets = []
|
|
if len(packets_raw) > 1:
|
|
for packet_raw in packets_raw:
|
|
packets.extend(JSONProtocol.decode(packet_raw))
|
|
return packets
|
|
else:
|
|
data = json.loads(packets_raw[0])
|
|
packet_type = PacketType(data["type"])
|
|
if packet_type not in PACKETS:
|
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
|
match packet_type:
|
|
case PacketType.INVOCATION:
|
|
return [
|
|
InvocationPacket(
|
|
header=data.get("header"),
|
|
invocation_id=data.get("invocationId"),
|
|
target=data["target"],
|
|
arguments=data.get("arguments"),
|
|
stream_ids=data.get("streamIds"),
|
|
)
|
|
]
|
|
case PacketType.COMPLETION:
|
|
return [
|
|
CompletionPacket(
|
|
header=data.get("header"),
|
|
invocation_id=data["invocationId"],
|
|
error=data.get("error"),
|
|
result=data.get("result"),
|
|
)
|
|
]
|
|
case PacketType.PING:
|
|
return [PingPacket()]
|
|
case PacketType.CLOSE:
|
|
return [
|
|
ClosePacket(
|
|
error=data.get("error"),
|
|
allow_reconnect=data.get("allowReconnect", False),
|
|
)
|
|
]
|
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
|
|
|
@staticmethod
|
|
def encode(packet: Packet) -> bytes:
|
|
payload: dict[str, Any] = {
|
|
"type": packet.type.value,
|
|
}
|
|
if packet.header:
|
|
payload["header"] = packet.header
|
|
if isinstance(packet, InvocationPacket):
|
|
payload.update(
|
|
{
|
|
"target": packet.target,
|
|
}
|
|
)
|
|
if packet.invocation_id is not None:
|
|
payload["invocationId"] = packet.invocation_id
|
|
if packet.arguments is not None:
|
|
payload["arguments"] = packet.arguments
|
|
if packet.stream_ids is not None:
|
|
payload["streamIds"] = packet.stream_ids
|
|
elif isinstance(packet, CompletionPacket):
|
|
payload.update(
|
|
{
|
|
"invocationId": packet.invocation_id,
|
|
}
|
|
)
|
|
if packet.error is not None:
|
|
payload["error"] = packet.error
|
|
if packet.result is not None:
|
|
payload["result"] = packet.result
|
|
elif isinstance(packet, PingPacket):
|
|
pass
|
|
elif isinstance(packet, ClosePacket):
|
|
payload.update(
|
|
{
|
|
"allowReconnect": packet.allow_reconnect,
|
|
}
|
|
)
|
|
if packet.error is not None:
|
|
payload["error"] = packet.error
|
|
return json.dumps(payload).encode("utf-8") + SEP
|
|
|
|
|
|
# Available codecs, looked up by protocol name.
PROTOCOLS: dict[str, Protocol] = dict(
    json=JSONProtocol,
    messagepack=MsgpackProtocol,
)
|