feat(signalr): support json protocol

This commit is contained in:
MingxuanGame
2025-07-27 11:45:55 +00:00
parent 9e44121427
commit 589927a300
5 changed files with 286 additions and 136 deletions

View File

@@ -20,7 +20,7 @@ class MessagePackArrayModel(BaseModel):
class Transport(BaseModel): class Transport(BaseModel):
transport: str transport: str
transfer_formats: list[str] = Field( transfer_formats: list[str] = Field(
default_factory=lambda: ["Binary"], alias="transferFormats" default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
) )

View File

@@ -8,41 +8,50 @@ from typing import Any
from app.config import settings from app.config import settings
from app.signalr.exception import InvokeException from app.signalr.exception import InvokeException
from app.signalr.packet import ( from app.signalr.packet import (
PacketType, CompletionPacket,
ResultKind, InvocationPacket,
encode_varint, Packet,
parse_packet, PingPacket,
Protocol,
) )
from app.signalr.store import ResultStore from app.signalr.store import ResultStore
from app.signalr.utils import get_signature from app.signalr.utils import get_signature
from fastapi import WebSocket from fastapi import WebSocket
import msgpack
from pydantic import BaseModel from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
class Client: class Client:
def __init__( def __init__(
self, connection_id: str, connection_token: str, connection: WebSocket self,
connection_id: str,
connection_token: str,
connection: WebSocket,
protocol: Protocol,
) -> None: ) -> None:
self.connection_id = connection_id self.connection_id = connection_id
self.connection_token = connection_token self.connection_token = connection_token
self.connection = connection self.connection = connection
self.procotol = protocol
self._listen_task: asyncio.Task | None = None self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None self._ping_task: asyncio.Task | None = None
self._store = ResultStore() self._store = ResultStore()
async def send_packet(self, type: PacketType, packet: list[Any]): async def send_packet(self, packet: Packet):
packet.insert(0, type.value) await self.connection.send_bytes(self.procotol.encode(packet))
payload = msgpack.packb(packet)
length = encode_varint(len(payload)) async def receive_packet(self) -> Packet:
await self.connection.send_bytes(length + payload) message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
raise WebSocketDisconnect(code=1008, reason="Empty message received.")
return self.procotol.decode(d)
async def _ping(self): async def _ping(self):
while True: while True:
try: try:
await self.send_packet(PacketType.PING, []) await self.send_packet(PingPacket())
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL) await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
except WebSocketDisconnect: except WebSocketDisconnect:
break break
@@ -61,7 +70,11 @@ class Hub:
self.waited_clients[connection_token] = timestamp self.waited_clients[connection_token] = timestamp
def add_client( def add_client(
self, connection_id: str, connection_token: str, connection: WebSocket self,
connection_id: str,
connection_token: str,
protocol: Protocol,
connection: WebSocket,
) -> Client: ) -> Client:
if connection_token in self.clients: if connection_token in self.clients:
raise ValueError( raise ValueError(
@@ -74,7 +87,7 @@ class Hub:
): ):
raise TimeoutError(f"Connection {connection_id} has waited too long.") raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token] del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection) client = Client(connection_id, connection_token, connection, protocol)
self.clients[connection_token] = client self.clients[connection_token] = client
task = asyncio.create_task(client._ping()) task = asyncio.create_task(client._ping())
self.tasks.add(task) self.tasks.add(task)
@@ -90,8 +103,8 @@ class Hub:
client._ping_task.cancel() client._ping_task.cancel()
await client.connection.close() await client.connection.close()
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]): async def send_packet(self, client: Client, packet: Packet) -> None:
await client.send_packet(type, packet) await client.send_packet(packet)
async def broadcast_call(self, method: str, *args: Any) -> None: async def broadcast_call(self, method: str, *args: Any) -> None:
tasks = [] tasks = []
@@ -103,11 +116,8 @@ class Hub:
jump = False jump = False
while not jump: while not jump:
try: try:
message = await client.connection.receive_bytes() packet = await client.receive_packet()
packet_type, packet_data = parse_packet(message) task = asyncio.create_task(self._handle_packet(client, packet))
task = asyncio.create_task(
self._handle_packet(client, packet_type, packet_data)
)
self.tasks.add(task) self.tasks.add(task)
task.add_done_callback(self.tasks.discard) task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
@@ -121,51 +131,30 @@ class Hub:
jump = True jump = True
await self.remove_client(client.connection_id) await self.remove_client(client.connection_id)
async def _handle_packet( async def _handle_packet(self, client: Client, packet: Packet) -> None:
self, client: Client, type: PacketType, packet: list[Any] if isinstance(packet, PingPacket):
) -> None: return
match type: elif isinstance(packet, InvocationPacket):
case PacketType.PING: args = packet.arguments or []
... error = None
case PacketType.INVOCATION: result = None
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration] try:
target: str = packet[2] result = await self.invoke_method(client, packet.target, args)
args: list[Any] | None = packet[3] except InvokeException as e:
if args is None: error = e.message
args = [] except Exception as e:
# streams: list[str] | None = packet[4] # TODO: stream support traceback.print_exc()
code = ResultKind.VOID error = str(e)
result = None if packet.invocation_id is not None:
try: await client.send_packet(
result = await self.invoke_method(client, target, args) CompletionPacket(
if result is not None: invocation_id=packet.invocation_id,
code = ResultKind.HAS_VALUE error=error,
except InvokeException as e: result=result,
code = ResultKind.ERROR
result = e.message
except Exception as e:
traceback.print_exc()
code = ResultKind.ERROR
result = str(e)
packet = [
{}, # header
invocation_id,
code.value,
]
if result is not None:
packet.append(result)
if invocation_id is not None:
await client.send_packet(
PacketType.COMPLETION,
packet,
) )
case PacketType.COMPLETION: )
invocation_id: str = packet[1] elif isinstance(packet, CompletionPacket):
code: ResultKind = ResultKind(packet[2]) client._store.add_result(packet.invocation_id, packet.result, packet.error)
result: Any = packet[3] if len(packet) > 3 else None
client._store.add_result(invocation_id, code, result)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any: async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None) method_ = getattr(self, method, None)
@@ -185,32 +174,28 @@ class Hub:
async def call(self, client: Client, method: str, *args: Any) -> Any: async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id() invocation_id = client._store.get_invocation_id()
await client.send_packet( await client.send_packet(
PacketType.INVOCATION, InvocationPacket(
[ header={},
{}, # header invocation_id=invocation_id,
invocation_id, target=method,
method, arguments=list(args),
list(args), stream_ids=None,
None, # streams )
],
) )
r = await client._store.fetch(invocation_id, None) r = await client._store.fetch(invocation_id, None)
if r[0] == ResultKind.HAS_VALUE: if r[1]:
return r[1]
if r[0] == ResultKind.ERROR:
raise InvokeException(r[1]) raise InvokeException(r[1])
return None return r[0]
async def call_noblock(self, client: Client, method: str, *args: Any) -> None: async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await client.send_packet( await client.send_packet(
PacketType.INVOCATION, InvocationPacket(
[ header={},
{}, # header invocation_id=None,
None, # invocation_id target=method,
method, arguments=list(args),
list(args), stream_ids=None,
None, # streams )
],
) )
return None return None

View File

@@ -1,7 +1,12 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from typing import Any import json
from typing import (
Any,
Protocol as TypingProtocol,
)
import msgpack import msgpack
@@ -18,43 +23,205 @@ class PacketType(IntEnum):
CLOSE = 7 CLOSE = 7
class ResultKind(IntEnum): @dataclass(kw_only=True)
ERROR = 1 class Packet:
VOID = 2 type: PacketType
HAS_VALUE = 3 header: dict[str, Any] | None = None
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]: @dataclass(kw_only=True)
length, offset = decode_varint(data) class InvocationPacket(Packet):
message_data = data[offset : offset + length] type: PacketType = PacketType.INVOCATION
# FIXME: custom deserializer for APIMod invocation_id: str | None
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs target: str
unpacked = msgpack.unpackb( arguments: list[Any] | None = None
message_data, raw=False, strict_map_key=False, use_list=True stream_ids: list[str] | None = None
)
return PacketType(unpacked[0]), unpacked[1:]
def encode_varint(value: int) -> bytes: @dataclass(kw_only=True)
result = [] class CompletionPacket(Packet):
while value >= 0x80: type: PacketType = PacketType.COMPLETION
result.append((value & 0x7F) | 0x80) invocation_id: str
value >>= 7 result: Any
result.append(value & 0x7F) error: str | None = None
return bytes(result)
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: @dataclass(kw_only=True)
result = 0 class PingPacket(Packet):
shift = 0 type: PacketType = PacketType.PING
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos PACKETS = {
PacketType.INVOCATION: InvocationPacket,
PacketType.COMPLETION: CompletionPacket,
PacketType.PING: PingPacket,
}
class Protocol(TypingProtocol):
@staticmethod
def decode(input: bytes) -> Packet: ...
@staticmethod
def encode(packet: Packet) -> bytes: ...
class MsgpackProtocol:
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
@staticmethod
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos
@staticmethod
def decode(input: bytes) -> Packet:
length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length]
# FIXME: custom deserializer for APIMod
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
unpacked = msgpack.unpackb(
message_data, raw=False, strict_map_key=False, use_list=True
)
packet_type = PacketType(unpacked[0])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return InvocationPacket(
header=unpacked[1],
invocation_id=unpacked[2],
target=unpacked[3],
arguments=unpacked[4] if len(unpacked) > 4 else None,
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
case PacketType.COMPLETION:
result_kind = unpacked[3]
return CompletionPacket(
header=unpacked[1],
invocation_id=unpacked[2],
error=unpacked[4] if result_kind == 1 else None,
result=unpacked[5] if result_kind == 3 else None,
)
case PacketType.PING:
return PingPacket()
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
if isinstance(packet, InvocationPacket):
payload.extend(
[
packet.invocation_id,
packet.target,
]
)
if packet.arguments is not None:
payload.append(packet.arguments)
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
result_kind = 2
if packet.error:
result_kind = 1
elif packet.result is None:
result_kind = 3
payload.extend(
[
packet.invocation_id,
result_kind,
packet.error or packet.result or None,
]
)
elif isinstance(packet, PingPacket):
pass
data = msgpack.packb(payload, use_bin_type=True)
return MsgpackProtocol._encode_varint(len(data)) + data
class JSONProtocol:
@staticmethod
def decode(input: bytes) -> Packet:
data = json.loads(input[:-1].decode("utf-8"))
packet_type = PacketType(data["type"])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return InvocationPacket(
header=data.get("header"),
invocation_id=data.get("invocationId"),
target=data["target"],
arguments=data.get("arguments"),
stream_ids=data.get("streamIds"),
)
case PacketType.COMPLETION:
return CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
case PacketType.PING:
return PingPacket()
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
"type": packet.type.value,
}
if packet.header:
payload["header"] = packet.header
if isinstance(packet, InvocationPacket):
payload.update(
{
"target": packet.target,
}
)
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = packet.arguments
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
payload.update(
{
"invocationId": packet.invocation_id,
}
)
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = packet.result
elif isinstance(packet, PingPacket):
pass
return json.dumps(payload).encode("utf-8")
PROTOCOLS: dict[str, Protocol] = {
"json": JSONProtocol,
"messagepack": MsgpackProtocol,
}

View File

@@ -12,7 +12,7 @@ from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs from .hub import Hubs
from .packet import SEP from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, Query, WebSocket from fastapi import APIRouter, Depends, Header, Query, WebSocket
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -62,13 +62,14 @@ async def connect(
await websocket.accept() await websocket.accept()
# handshake # handshake
handshake = await websocket.receive_bytes() handshake = await websocket.receive()
handshake_payload = json.loads(handshake[:-1]) message = handshake.get("bytes") or handshake.get("text")
if not message:
await websocket.close(code=1008)
return
handshake_payload = json.loads(message[:-1])
error = "" error = ""
if (protocol := handshake_payload.get("protocol")) != "messagepack" or ( protocol = handshake_payload.get("protocol", "json")
handshake_payload.get("version")
) != 1:
error = f"Requested protocol '{protocol}' is not available."
client = None client = None
try: try:
@@ -76,7 +77,10 @@ async def connect(
connection_id=user_id, connection_id=user_id,
connection_token=id, connection_token=id,
connection=websocket, connection=websocket,
protocol=PROTOCOLS[protocol],
) )
except KeyError:
error = f"Protocol '{protocol}' is not supported."
except TimeoutError: except TimeoutError:
error = f"Connection {id} has waited too long." error = f"Connection {id} has waited too long."
except ValueError as e: except ValueError as e:

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import sys import sys
from typing import Any, Literal from typing import Any
from .packet import ResultKind
class ResultStore: class ResultStore:
@@ -22,21 +20,17 @@ class ResultStore:
return str(s) return str(s)
def add_result( def add_result(
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None self, invocation_id: str, result: Any, error: str | None = None
) -> None: ) -> None:
if isinstance(invocation_id, str) and invocation_id.isdecimal(): if isinstance(invocation_id, str) and invocation_id.isdecimal():
if future := self._futures.get(invocation_id): if future := self._futures.get(invocation_id):
future.set_result((type, result)) future.set_result((result, error))
async def fetch( async def fetch(
self, self,
invocation_id: str, invocation_id: str,
timeout: float | None, # noqa: ASYNC109 timeout: float | None, # noqa: ASYNC109
) -> ( ) -> tuple[Any, str | None]:
tuple[Literal[ResultKind.ERROR], str]
| tuple[Literal[ResultKind.VOID], None]
| tuple[Literal[ResultKind.HAS_VALUE], Any]
):
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
self._futures[invocation_id] = future self._futures[invocation_id] = future
try: try: