feat(signalr): support json protocol

This commit is contained in:
MingxuanGame
2025-07-27 11:45:55 +00:00
parent 9e44121427
commit 589927a300
5 changed files with 286 additions and 136 deletions

View File

@@ -20,7 +20,7 @@ class MessagePackArrayModel(BaseModel):
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(
default_factory=lambda: ["Binary"], alias="transferFormats"
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
)

View File

@@ -8,41 +8,50 @@ from typing import Any
from app.config import settings
from app.signalr.exception import InvokeException
from app.signalr.packet import (
PacketType,
ResultKind,
encode_varint,
parse_packet,
CompletionPacket,
InvocationPacket,
Packet,
PingPacket,
Protocol,
)
from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
import msgpack
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class Client:
def __init__(
self, connection_id: str, connection_token: str, connection: WebSocket
self,
connection_id: str,
connection_token: str,
connection: WebSocket,
protocol: Protocol,
) -> None:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self.procotol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
async def send_packet(self, type: PacketType, packet: list[Any]):
packet.insert(0, type.value)
payload = msgpack.packb(packet)
length = encode_varint(len(payload))
await self.connection.send_bytes(length + payload)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
async def receive_packet(self) -> Packet:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
raise WebSocketDisconnect(code=1008, reason="Empty message received.")
return self.procotol.decode(d)
async def _ping(self):
while True:
try:
await self.send_packet(PacketType.PING, [])
await self.send_packet(PingPacket())
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
except WebSocketDisconnect:
break
@@ -61,7 +70,11 @@ class Hub:
self.waited_clients[connection_token] = timestamp
def add_client(
self, connection_id: str, connection_token: str, connection: WebSocket
self,
connection_id: str,
connection_token: str,
protocol: Protocol,
connection: WebSocket,
) -> Client:
if connection_token in self.clients:
raise ValueError(
@@ -74,7 +87,7 @@ class Hub:
):
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection)
client = Client(connection_id, connection_token, connection, protocol)
self.clients[connection_token] = client
task = asyncio.create_task(client._ping())
self.tasks.add(task)
@@ -90,8 +103,8 @@ class Hub:
client._ping_task.cancel()
await client.connection.close()
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
await client.send_packet(type, packet)
async def send_packet(self, client: Client, packet: Packet) -> None:
await client.send_packet(packet)
async def broadcast_call(self, method: str, *args: Any) -> None:
tasks = []
@@ -103,11 +116,8 @@ class Hub:
jump = False
while not jump:
try:
message = await client.connection.receive_bytes()
packet_type, packet_data = parse_packet(message)
task = asyncio.create_task(
self._handle_packet(client, packet_type, packet_data)
)
packet = await client.receive_packet()
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
@@ -121,51 +131,30 @@ class Hub:
jump = True
await self.remove_client(client.connection_id)
async def _handle_packet(
self, client: Client, type: PacketType, packet: list[Any]
) -> None:
match type:
case PacketType.PING:
...
case PacketType.INVOCATION:
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
target: str = packet[2]
args: list[Any] | None = packet[3]
if args is None:
args = []
# streams: list[str] | None = packet[4] # TODO: stream support
code = ResultKind.VOID
result = None
try:
result = await self.invoke_method(client, target, args)
if result is not None:
code = ResultKind.HAS_VALUE
except InvokeException as e:
code = ResultKind.ERROR
result = e.message
except Exception as e:
traceback.print_exc()
code = ResultKind.ERROR
result = str(e)
packet = [
{}, # header
invocation_id,
code.value,
]
if result is not None:
packet.append(result)
if invocation_id is not None:
await client.send_packet(
PacketType.COMPLETION,
packet,
async def _handle_packet(self, client: Client, packet: Packet) -> None:
if isinstance(packet, PingPacket):
return
elif isinstance(packet, InvocationPacket):
args = packet.arguments or []
error = None
result = None
try:
result = await self.invoke_method(client, packet.target, args)
except InvokeException as e:
error = e.message
except Exception as e:
traceback.print_exc()
error = str(e)
if packet.invocation_id is not None:
await client.send_packet(
CompletionPacket(
invocation_id=packet.invocation_id,
error=error,
result=result,
)
case PacketType.COMPLETION:
invocation_id: str = packet[1]
code: ResultKind = ResultKind(packet[2])
result: Any = packet[3] if len(packet) > 3 else None
client._store.add_result(invocation_id, code, result)
)
elif isinstance(packet, CompletionPacket):
client._store.add_result(packet.invocation_id, packet.result, packet.error)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None)
@@ -185,32 +174,28 @@ class Hub:
async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id()
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
invocation_id,
method,
list(args),
None, # streams
],
InvocationPacket(
header={},
invocation_id=invocation_id,
target=method,
arguments=list(args),
stream_ids=None,
)
)
r = await client._store.fetch(invocation_id, None)
if r[0] == ResultKind.HAS_VALUE:
return r[1]
if r[0] == ResultKind.ERROR:
if r[1]:
raise InvokeException(r[1])
return None
return r[0]
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
None, # invocation_id
method,
list(args),
None, # streams
],
InvocationPacket(
header={},
invocation_id=None,
target=method,
arguments=list(args),
stream_ids=None,
)
)
return None

View File

@@ -1,7 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Any
import json
from typing import (
Any,
Protocol as TypingProtocol,
)
import msgpack
@@ -18,43 +23,205 @@ class PacketType(IntEnum):
CLOSE = 7
class ResultKind(IntEnum):
ERROR = 1
VOID = 2
HAS_VALUE = 3
@dataclass(kw_only=True)
class Packet:
type: PacketType
header: dict[str, Any] | None = None
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
length, offset = decode_varint(data)
message_data = data[offset : offset + length]
# FIXME: custom deserializer for APIMod
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
unpacked = msgpack.unpackb(
message_data, raw=False, strict_map_key=False, use_list=True
)
return PacketType(unpacked[0]), unpacked[1:]
@dataclass(kw_only=True)
class InvocationPacket(Packet):
type: PacketType = PacketType.INVOCATION
invocation_id: str | None
target: str
arguments: list[Any] | None = None
stream_ids: list[str] | None = None
def encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
@dataclass(kw_only=True)
class CompletionPacket(Packet):
type: PacketType = PacketType.COMPLETION
invocation_id: str
result: Any
error: str | None = None
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
@dataclass(kw_only=True)
class PingPacket(Packet):
type: PacketType = PacketType.PING
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos
PACKETS = {
PacketType.INVOCATION: InvocationPacket,
PacketType.COMPLETION: CompletionPacket,
PacketType.PING: PingPacket,
}
class Protocol(TypingProtocol):
@staticmethod
def decode(input: bytes) -> Packet: ...
@staticmethod
def encode(packet: Packet) -> bytes: ...
class MsgpackProtocol:
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
@staticmethod
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos
@staticmethod
def decode(input: bytes) -> Packet:
length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length]
# FIXME: custom deserializer for APIMod
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
unpacked = msgpack.unpackb(
message_data, raw=False, strict_map_key=False, use_list=True
)
packet_type = PacketType(unpacked[0])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return InvocationPacket(
header=unpacked[1],
invocation_id=unpacked[2],
target=unpacked[3],
arguments=unpacked[4] if len(unpacked) > 4 else None,
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
case PacketType.COMPLETION:
result_kind = unpacked[3]
return CompletionPacket(
header=unpacked[1],
invocation_id=unpacked[2],
error=unpacked[4] if result_kind == 1 else None,
result=unpacked[5] if result_kind == 3 else None,
)
case PacketType.PING:
return PingPacket()
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
if isinstance(packet, InvocationPacket):
payload.extend(
[
packet.invocation_id,
packet.target,
]
)
if packet.arguments is not None:
payload.append(packet.arguments)
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
result_kind = 2
if packet.error:
result_kind = 1
elif packet.result is None:
result_kind = 3
payload.extend(
[
packet.invocation_id,
result_kind,
packet.error or packet.result or None,
]
)
elif isinstance(packet, PingPacket):
pass
data = msgpack.packb(payload, use_bin_type=True)
return MsgpackProtocol._encode_varint(len(data)) + data
class JSONProtocol:
@staticmethod
def decode(input: bytes) -> Packet:
data = json.loads(input[:-1].decode("utf-8"))
packet_type = PacketType(data["type"])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return InvocationPacket(
header=data.get("header"),
invocation_id=data.get("invocationId"),
target=data["target"],
arguments=data.get("arguments"),
stream_ids=data.get("streamIds"),
)
case PacketType.COMPLETION:
return CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
case PacketType.PING:
return PingPacket()
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
"type": packet.type.value,
}
if packet.header:
payload["header"] = packet.header
if isinstance(packet, InvocationPacket):
payload.update(
{
"target": packet.target,
}
)
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = packet.arguments
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
payload.update(
{
"invocationId": packet.invocation_id,
}
)
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = packet.result
elif isinstance(packet, PingPacket):
pass
return json.dumps(payload).encode("utf-8")
PROTOCOLS: dict[str, Protocol] = {
"json": JSONProtocol,
"messagepack": MsgpackProtocol,
}

View File

@@ -12,7 +12,7 @@ from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs
from .packet import SEP
from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -62,13 +62,14 @@ async def connect(
await websocket.accept()
# handshake
handshake = await websocket.receive_bytes()
handshake_payload = json.loads(handshake[:-1])
handshake = await websocket.receive()
message = handshake.get("bytes") or handshake.get("text")
if not message:
await websocket.close(code=1008)
return
handshake_payload = json.loads(message[:-1])
error = ""
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
handshake_payload.get("version")
) != 1:
error = f"Requested protocol '{protocol}' is not available."
protocol = handshake_payload.get("protocol", "json")
client = None
try:
@@ -76,7 +77,10 @@ async def connect(
connection_id=user_id,
connection_token=id,
connection=websocket,
protocol=PROTOCOLS[protocol],
)
except KeyError:
error = f"Protocol '{protocol}' is not supported."
except TimeoutError:
error = f"Connection {id} has waited too long."
except ValueError as e:

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import asyncio
import sys
from typing import Any, Literal
from .packet import ResultKind
from typing import Any
class ResultStore:
@@ -22,21 +20,17 @@ class ResultStore:
return str(s)
def add_result(
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
self, invocation_id: str, result: Any, error: str | None = None
) -> None:
if isinstance(invocation_id, str) and invocation_id.isdecimal():
if future := self._futures.get(invocation_id):
future.set_result((type, result))
future.set_result((result, error))
async def fetch(
self,
invocation_id: str,
timeout: float | None, # noqa: ASYNC109
) -> (
tuple[Literal[ResultKind.ERROR], str]
| tuple[Literal[ResultKind.VOID], None]
| tuple[Literal[ResultKind.HAS_VALUE], Any]
):
) -> tuple[Any, str | None]:
future = asyncio.get_event_loop().create_future()
self._futures[invocation_id] = future
try: