492 lines
17 KiB
Python
492 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import datetime
|
|
from enum import Enum, IntEnum
|
|
import inspect
|
|
import json
|
|
from types import NoneType, UnionType
|
|
from typing import (
|
|
Any,
|
|
Protocol as TypingProtocol,
|
|
Union,
|
|
get_args,
|
|
get_origin,
|
|
)
|
|
|
|
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
|
from app.utils import camel_to_snake, snake_to_camel
|
|
|
|
import msgpack_lazer_api as m
|
|
from pydantic import BaseModel
|
|
|
|
# ASCII record separator (0x1E). JSONProtocol terminates every encoded JSON
# message with it and splits incoming buffers on it.
SEP = bytes([0x1E])
|
|
|
|
|
|
class PacketType(IntEnum):
    """Numeric message-type codes of the SignalR hub protocol.

    The integer value is what both codecs put on the wire: the ``type`` key of
    a JSON message and the first element of a MessagePack message array.
    """

    INVOCATION = 1
    STREAM_ITEM = 2
    COMPLETION = 3
    STREAM_INVOCATION = 4
    CANCEL_INVOCATION = 5
    PING = 6
    CLOSE = 7
|
|
|
|
|
|
@dataclass(kw_only=True)
class Packet:
    """Base dataclass for every hub-protocol message handled in this module.

    All fields are keyword-only so subclasses may add required fields after
    the defaulted ones.
    """

    # Message-type code; each concrete subclass pins its own default.
    type: PacketType
    # Optional header mapping; ``None`` means "no headers".
    header: dict[str, Any] | None = None
|
|
|
|
|
|
@dataclass(kw_only=True)
class InvocationPacket(Packet):
    """A request to call hub method ``target`` (``PacketType.INVOCATION``)."""

    type: PacketType = PacketType.INVOCATION
    # ``None`` when the call carries no invocation id (the JSON encoder then
    # leaves the ``invocationId`` key out).
    invocation_id: str | None
    # Name of the method to invoke on the receiving side.
    target: str
    # Positional call arguments; ``None`` when absent from the wire.
    arguments: list[Any] | None = None
    # Stream identifiers attached to the call; ``None`` when absent.
    stream_ids: list[str] | None = None
|
|
|
|
|
|
@dataclass(kw_only=True)
class CompletionPacket(Packet):
    """The outcome of an earlier invocation (``PacketType.COMPLETION``)."""

    type: PacketType = PacketType.COMPLETION
    # Id of the invocation this message completes.
    invocation_id: str
    # Returned value; ``None`` for a void completion or an error.
    result: Any
    # Error message when the invocation failed, otherwise ``None``.
    error: str | None = None
|
|
|
|
|
|
@dataclass(kw_only=True)
class PingPacket(Packet):
    """A payload-free keep-alive message (``PacketType.PING``)."""

    type: PacketType = PacketType.PING
|
|
|
|
|
|
@dataclass(kw_only=True)
class ClosePacket(Packet):
    """A connection-close notification (``PacketType.CLOSE``)."""

    type: PacketType = PacketType.CLOSE
    # Reason for closing, or ``None`` for a clean close.
    error: str | None = None
    # Whether the peer is told it may reconnect.
    allow_reconnect: bool = False
|
|
|
|
|
|
# Packet types both decoders accept, mapped to the dataclass that models them.
# Each key is taken from the class's own ``type`` default (dataclass keeps a
# plain default as the class attribute), so key and class cannot disagree.
PACKETS = {
    packet_cls.type: packet_cls
    for packet_cls in (
        InvocationPacket,
        CompletionPacket,
        PingPacket,
        ClosePacket,
    )
}
|
|
|
|
|
|
class Protocol(TypingProtocol):
    """Structural interface implemented by the wire-format codec classes.

    The codecs are used as classes rather than instances (``PROTOCOLS`` maps
    names to the classes themselves), hence only static/class methods.
    """

    @staticmethod
    def decode(input: bytes) -> list[Packet]:
        """Parse raw transport bytes into the packets they contain."""

    @staticmethod
    def encode(packet: Packet) -> bytes:
        """Serialize one packet into framed transport bytes."""

    @classmethod
    def validate_object(cls, v: Any, typ: type) -> Any:
        """Convert a decoded wire value ``v`` into a value of type ``typ``."""
|
|
|
|
|
|
class MsgpackProtocol:
|
|
@classmethod
|
|
def serialize_msgpack(cls, v: Any) -> Any:
|
|
typ = v.__class__
|
|
if issubclass(typ, BaseModel):
|
|
return cls.serialize_to_list(v)
|
|
elif issubclass(typ, list):
|
|
return [cls.serialize_msgpack(item) for item in v]
|
|
elif issubclass(typ, datetime.datetime):
|
|
return [v, 0]
|
|
elif isinstance(v, dict):
|
|
return {
|
|
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
|
for k, value in v.items()
|
|
}
|
|
elif issubclass(typ, Enum):
|
|
list_ = list(typ)
|
|
return list_.index(v) if v in list_ else v.value
|
|
return v
|
|
|
|
@classmethod
|
|
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
|
|
values = []
|
|
for field, info in value.__class__.model_fields.items():
|
|
metadata = next(
|
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
|
)
|
|
if metadata and metadata.member_ignore:
|
|
continue
|
|
values.append(cls.serialize_msgpack(v=getattr(value, field)))
|
|
if issubclass(value.__class__, SignalRUnionMessage):
|
|
return [value.__class__.union_type, values]
|
|
else:
|
|
return values
|
|
|
|
@staticmethod
|
|
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
|
if isinstance(v, list):
|
|
d = {}
|
|
i = 0
|
|
for field, info in typ.model_fields.items():
|
|
metadata = next(
|
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
|
)
|
|
if metadata and metadata.member_ignore:
|
|
continue
|
|
anno = info.annotation
|
|
if anno is None:
|
|
d[camel_to_snake(field)] = v[i]
|
|
else:
|
|
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
|
i += 1
|
|
return d
|
|
return v
|
|
|
|
@staticmethod
|
|
def _encode_varint(value: int) -> bytes:
|
|
result = []
|
|
while value >= 0x80:
|
|
result.append((value & 0x7F) | 0x80)
|
|
value >>= 7
|
|
result.append(value & 0x7F)
|
|
return bytes(result)
|
|
|
|
@staticmethod
|
|
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
|
result = 0
|
|
shift = 0
|
|
pos = offset
|
|
|
|
while pos < len(data):
|
|
byte = data[pos]
|
|
result |= (byte & 0x7F) << shift
|
|
pos += 1
|
|
if (byte & 0x80) == 0:
|
|
break
|
|
shift += 7
|
|
|
|
return result, pos
|
|
|
|
@staticmethod
|
|
def decode(input: bytes) -> list[Packet]:
|
|
length, offset = MsgpackProtocol._decode_varint(input)
|
|
message_data = input[offset : offset + length]
|
|
unpacked = m.decode(message_data)
|
|
packet_type = PacketType(unpacked[0])
|
|
if packet_type not in PACKETS:
|
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
|
match packet_type:
|
|
case PacketType.INVOCATION:
|
|
return [
|
|
InvocationPacket(
|
|
header=unpacked[1],
|
|
invocation_id=unpacked[2],
|
|
target=unpacked[3],
|
|
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
|
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
|
)
|
|
]
|
|
case PacketType.COMPLETION:
|
|
result_kind = unpacked[3]
|
|
return [
|
|
CompletionPacket(
|
|
header=unpacked[1],
|
|
invocation_id=unpacked[2],
|
|
error=unpacked[4] if result_kind == 1 else None,
|
|
result=unpacked[5] if result_kind == 3 else None,
|
|
)
|
|
]
|
|
case PacketType.PING:
|
|
return [PingPacket()]
|
|
case PacketType.CLOSE:
|
|
return [
|
|
ClosePacket(
|
|
error=unpacked[1],
|
|
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
|
|
)
|
|
]
|
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
|
|
|
@classmethod
|
|
def validate_object(cls, v: Any, typ: type) -> Any:
|
|
if issubclass(typ, BaseModel):
|
|
return typ.model_validate(obj=cls.process_object(v, typ))
|
|
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
|
return v[0]
|
|
elif isinstance(v, list):
|
|
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
|
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
|
list_ = list(typ)
|
|
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
|
elif get_origin(typ) is dict:
|
|
return {
|
|
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
|
v, get_args(typ)[1]
|
|
)
|
|
for k, v in v.items()
|
|
}
|
|
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
|
args = get_args(typ)
|
|
if len(args) == 2 and NoneType in args:
|
|
non_none_args = [arg for arg in args if arg is not NoneType]
|
|
if len(non_none_args) == 1:
|
|
if v is None:
|
|
return None
|
|
return cls.validate_object(v, non_none_args[0])
|
|
|
|
# suppose use `MessagePack-CSharp Union | None`
|
|
# except `X (Other Type) | None`
|
|
if NoneType in args and v is None:
|
|
return None
|
|
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
|
|
raise ValueError(
|
|
f"Cannot validate {v} to {typ}, "
|
|
"only SignalRUnionMessage subclasses are supported"
|
|
)
|
|
union_type = v[0]
|
|
for arg in args:
|
|
assert issubclass(arg, SignalRUnionMessage)
|
|
if arg.union_type == union_type:
|
|
return cls.validate_object(v[1], arg)
|
|
return v
|
|
|
|
@staticmethod
|
|
def encode(packet: Packet) -> bytes:
|
|
payload = [packet.type.value, packet.header or {}]
|
|
if isinstance(packet, InvocationPacket):
|
|
payload.extend(
|
|
[
|
|
packet.invocation_id,
|
|
packet.target,
|
|
]
|
|
)
|
|
if packet.arguments is not None:
|
|
payload.append(
|
|
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
|
|
)
|
|
if packet.stream_ids is not None:
|
|
payload.append(packet.stream_ids)
|
|
elif isinstance(packet, CompletionPacket):
|
|
result_kind = 2
|
|
if packet.error:
|
|
result_kind = 1
|
|
elif packet.result is not None:
|
|
result_kind = 3
|
|
payload.extend(
|
|
[
|
|
packet.invocation_id,
|
|
result_kind,
|
|
packet.error
|
|
or MsgpackProtocol.serialize_msgpack(packet.result)
|
|
or None,
|
|
]
|
|
)
|
|
elif isinstance(packet, ClosePacket):
|
|
payload.extend(
|
|
[
|
|
packet.error or "",
|
|
packet.allow_reconnect,
|
|
]
|
|
)
|
|
elif isinstance(packet, PingPacket):
|
|
payload.pop(-1)
|
|
data = m.encode(payload)
|
|
return MsgpackProtocol._encode_varint(len(data)) + data
|
|
|
|
|
|
class JSONProtocol:
|
|
@classmethod
|
|
def serialize_to_json(cls, v: Any):
|
|
typ = v.__class__
|
|
if issubclass(typ, BaseModel):
|
|
return cls.serialize_model(v)
|
|
elif isinstance(v, dict):
|
|
return {
|
|
cls.serialize_to_json(k): cls.serialize_to_json(value)
|
|
for k, value in v.items()
|
|
}
|
|
elif isinstance(v, list):
|
|
return [cls.serialize_to_json(item) for item in v]
|
|
elif isinstance(v, datetime.datetime):
|
|
return v.isoformat()
|
|
elif isinstance(v, Enum):
|
|
return v.value
|
|
return v
|
|
|
|
@classmethod
|
|
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
|
|
d = {}
|
|
for field, info in v.__class__.model_fields.items():
|
|
metadata = next(
|
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
|
)
|
|
if metadata and metadata.json_ignore:
|
|
continue
|
|
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
|
|
cls.serialize_to_json(getattr(v, field))
|
|
)
|
|
if issubclass(v.__class__, SignalRUnionMessage):
|
|
return {
|
|
"$dtype": v.__class__.__name__,
|
|
"$value": d,
|
|
}
|
|
return d
|
|
|
|
@staticmethod
|
|
def process_object(
|
|
v: Any, typ: type[BaseModel], from_union: bool = False
|
|
) -> dict[str, Any]:
|
|
d = {}
|
|
for field, info in typ.model_fields.items():
|
|
metadata = next(
|
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
|
)
|
|
if metadata and metadata.json_ignore:
|
|
continue
|
|
value = v.get(snake_to_camel(field, not from_union))
|
|
anno = typ.model_fields[field].annotation
|
|
if anno is None:
|
|
d[field] = value
|
|
continue
|
|
d[field] = JSONProtocol.validate_object(value, anno)
|
|
return d
|
|
|
|
@staticmethod
|
|
def decode(input: bytes) -> list[Packet]:
|
|
packets_raw = input.removesuffix(SEP).split(SEP)
|
|
packets = []
|
|
if len(packets_raw) > 1:
|
|
for packet_raw in packets_raw:
|
|
packets.extend(JSONProtocol.decode(packet_raw))
|
|
return packets
|
|
else:
|
|
data = json.loads(packets_raw[0])
|
|
packet_type = PacketType(data["type"])
|
|
if packet_type not in PACKETS:
|
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
|
match packet_type:
|
|
case PacketType.INVOCATION:
|
|
return [
|
|
InvocationPacket(
|
|
header=data.get("header"),
|
|
invocation_id=data.get("invocationId"),
|
|
target=data["target"],
|
|
arguments=data.get("arguments"),
|
|
stream_ids=data.get("streamIds"),
|
|
)
|
|
]
|
|
case PacketType.COMPLETION:
|
|
return [
|
|
CompletionPacket(
|
|
header=data.get("header"),
|
|
invocation_id=data["invocationId"],
|
|
error=data.get("error"),
|
|
result=data.get("result"),
|
|
)
|
|
]
|
|
case PacketType.PING:
|
|
return [PingPacket()]
|
|
case PacketType.CLOSE:
|
|
return [
|
|
ClosePacket(
|
|
error=data.get("error"),
|
|
allow_reconnect=data.get("allowReconnect", False),
|
|
)
|
|
]
|
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
|
|
|
@classmethod
|
|
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
|
|
if issubclass(typ, BaseModel):
|
|
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
|
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
|
return datetime.datetime.fromisoformat(v)
|
|
elif isinstance(v, list):
|
|
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
|
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
|
list_ = list(typ)
|
|
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
|
elif get_origin(typ) is dict:
|
|
return {
|
|
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
|
v, get_args(typ)[1]
|
|
)
|
|
for k, v in v.items()
|
|
}
|
|
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
|
args = get_args(typ)
|
|
if len(args) == 2 and NoneType in args:
|
|
non_none_args = [arg for arg in args if arg is not NoneType]
|
|
if len(non_none_args) == 1:
|
|
if v is None:
|
|
return None
|
|
return cls.validate_object(v, non_none_args[0])
|
|
|
|
# suppose use `MessagePack-CSharp Union | None`
|
|
# except `X (Other Type) | None`
|
|
if NoneType in args and v is None:
|
|
return None
|
|
if not all(
|
|
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
|
):
|
|
raise ValueError(
|
|
f"Cannot validate {v} to {typ}, "
|
|
"only SignalRUnionMessage subclasses are supported"
|
|
)
|
|
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
|
|
union_type = v["$dtype"]
|
|
for arg in args:
|
|
assert issubclass(arg, SignalRUnionMessage)
|
|
if arg.__name__ == union_type:
|
|
return cls.validate_object(v["$value"], arg, True)
|
|
return v
|
|
|
|
@staticmethod
|
|
def encode(packet: Packet) -> bytes:
|
|
payload: dict[str, Any] = {
|
|
"type": packet.type.value,
|
|
}
|
|
if packet.header:
|
|
payload["header"] = packet.header
|
|
if isinstance(packet, InvocationPacket):
|
|
payload.update(
|
|
{
|
|
"target": packet.target,
|
|
}
|
|
)
|
|
if packet.invocation_id is not None:
|
|
payload["invocationId"] = packet.invocation_id
|
|
if packet.arguments is not None:
|
|
payload["arguments"] = [
|
|
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
|
|
]
|
|
if packet.stream_ids is not None:
|
|
payload["streamIds"] = packet.stream_ids
|
|
elif isinstance(packet, CompletionPacket):
|
|
payload.update(
|
|
{
|
|
"invocationId": packet.invocation_id,
|
|
}
|
|
)
|
|
if packet.error is not None:
|
|
payload["error"] = packet.error
|
|
if packet.result is not None:
|
|
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
|
|
elif isinstance(packet, PingPacket):
|
|
pass
|
|
elif isinstance(packet, ClosePacket):
|
|
payload.update(
|
|
{
|
|
"allowReconnect": packet.allow_reconnect,
|
|
}
|
|
)
|
|
if packet.error is not None:
|
|
payload["error"] = packet.error
|
|
return json.dumps(payload).encode("utf-8") + SEP
|
|
|
|
|
|
# Wire codecs keyed by protocol name (presumably the name the client requests
# in the SignalR handshake — confirm with the caller). The classes themselves
# are stored, since every codec member is a static/class method.
PROTOCOLS: dict[str, Protocol] = dict(
    json=JSONProtocol,
    messagepack=MsgpackProtocol,
)
|