From 589927a3004b2bfb7d7ef78790f9c8affadbbb26 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 27 Jul 2025 11:45:55 +0000 Subject: [PATCH] feat(signalr): support json protocol --- app/models/signalr.py | 2 +- app/signalr/hub/hub.py | 155 +++++++++++++-------------- app/signalr/packet.py | 233 +++++++++++++++++++++++++++++++++++------ app/signalr/router.py | 18 ++-- app/signalr/store.py | 14 +-- 5 files changed, 286 insertions(+), 136 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index fb4f55f..49db11f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -20,7 +20,7 @@ class MessagePackArrayModel(BaseModel): class Transport(BaseModel): transport: str transfer_formats: list[str] = Field( - default_factory=lambda: ["Binary"], alias="transferFormats" + default_factory=lambda: ["Binary", "Text"], alias="transferFormats" ) diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 3175b8d..42c6aef 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -8,41 +8,50 @@ from typing import Any from app.config import settings from app.signalr.exception import InvokeException from app.signalr.packet import ( - PacketType, - ResultKind, - encode_varint, - parse_packet, + CompletionPacket, + InvocationPacket, + Packet, + PingPacket, + Protocol, ) from app.signalr.store import ResultStore from app.signalr.utils import get_signature from fastapi import WebSocket -import msgpack from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect class Client: def __init__( - self, connection_id: str, connection_token: str, connection: WebSocket + self, + connection_id: str, + connection_token: str, + connection: WebSocket, + protocol: Protocol, ) -> None: self.connection_id = connection_id self.connection_token = connection_token self.connection = connection + self.procotol = protocol self._listen_task: asyncio.Task | None = None self._ping_task: asyncio.Task | None = None self._store = ResultStore() - async def send_packet(self, type: PacketType, packet: list[Any]): - packet.insert(0, type.value) - payload = msgpack.packb(packet) - length = encode_varint(len(payload)) - await self.connection.send_bytes(length + payload) + async def send_packet(self, packet: Packet): + await self.connection.send_bytes(self.procotol.encode(packet)) + + async def receive_packet(self) -> Packet: + message = await self.connection.receive() + d = message.get("bytes") or message.get("text", "").encode() + if not d: + raise WebSocketDisconnect(code=1008, reason="Empty message received.") + return self.procotol.decode(d) async def _ping(self): while True: try: - await self.send_packet(PacketType.PING, []) + await self.send_packet(PingPacket()) await asyncio.sleep(settings.SIGNALR_PING_INTERVAL) except WebSocketDisconnect: break @@ -61,7 +70,11 @@ class Hub: self.waited_clients[connection_token] = timestamp def add_client( - self, connection_id: str, connection_token: str, connection: WebSocket + self, + connection_id: str, + connection_token: str, + protocol: Protocol, + connection: WebSocket, ) -> Client: if connection_token in self.clients: raise ValueError( @@ -74,7 +87,7 @@ class Hub: ): raise TimeoutError(f"Connection {connection_id} has waited too long.") del self.waited_clients[connection_token] - client = Client(connection_id, connection_token, connection) + client = Client(connection_id, connection_token, connection, protocol) self.clients[connection_token] = client task = asyncio.create_task(client._ping()) self.tasks.add(task) @@ -90,8 +103,8 @@ class Hub: client._ping_task.cancel() await client.connection.close() - async def send_packet(self, client: Client, type: PacketType, packet: list[Any]): - await client.send_packet(type, packet) + async def send_packet(self, client: Client, packet: Packet) -> None: + await client.send_packet(packet) async def broadcast_call(self, method: str, *args: Any) -> None: tasks = [] @@ -103,11 +116,8 @@ class Hub: jump = False while not jump: try: - message = await client.connection.receive_bytes() - packet_type, packet_data = parse_packet(message) - task = asyncio.create_task( - self._handle_packet(client, packet_type, packet_data) - ) + packet = await client.receive_packet() + task = asyncio.create_task(self._handle_packet(client, packet)) self.tasks.add(task) task.add_done_callback(self.tasks.discard) except WebSocketDisconnect as e: @@ -121,51 +131,30 @@ class Hub: jump = True await self.remove_client(client.connection_id) - async def _handle_packet( - self, client: Client, type: PacketType, packet: list[Any] - ) -> None: - match type: - case PacketType.PING: - ... - case PacketType.INVOCATION: - invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration] - target: str = packet[2] - args: list[Any] | None = packet[3] - if args is None: - args = [] - # streams: list[str] | None = packet[4] # TODO: stream support - code = ResultKind.VOID - result = None - try: - result = await self.invoke_method(client, target, args) - if result is not None: - code = ResultKind.HAS_VALUE - except InvokeException as e: - code = ResultKind.ERROR - result = e.message - - except Exception as e: - traceback.print_exc() - code = ResultKind.ERROR - result = str(e) - - packet = [ - {}, # header - invocation_id, - code.value, - ] - if result is not None: - packet.append(result) - if invocation_id is not None: - await client.send_packet( - PacketType.COMPLETION, - packet, + async def _handle_packet(self, client: Client, packet: Packet) -> None: + if isinstance(packet, PingPacket): + return + elif isinstance(packet, InvocationPacket): + args = packet.arguments or [] + error = None + result = None + try: + result = await self.invoke_method(client, packet.target, args) + except InvokeException as e: + error = e.message + except Exception as e: + traceback.print_exc() + error = str(e) + if packet.invocation_id is not None: + await client.send_packet( + CompletionPacket( + invocation_id=packet.invocation_id, + error=error, + result=result, ) - case PacketType.COMPLETION: - invocation_id: str = packet[1] - code: ResultKind = ResultKind(packet[2]) - result: Any = packet[3] if len(packet) > 3 else None - client._store.add_result(invocation_id, code, result) + ) + elif isinstance(packet, CompletionPacket): + client._store.add_result(packet.invocation_id, packet.result, packet.error) async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any: method_ = getattr(self, method, None) @@ -185,32 +174,28 @@ class Hub: async def call(self, client: Client, method: str, *args: Any) -> Any: invocation_id = client._store.get_invocation_id() await client.send_packet( - PacketType.INVOCATION, - [ - {}, # header - invocation_id, - method, - list(args), - None, # streams - ], + InvocationPacket( + header={}, + invocation_id=invocation_id, + target=method, + arguments=list(args), + stream_ids=None, + ) ) r = await client._store.fetch(invocation_id, None) - if r[0] == ResultKind.HAS_VALUE: - return r[1] - if r[0] == ResultKind.ERROR: + if r[1]: raise InvokeException(r[1]) - return None + return r[0] async def call_noblock(self, client: Client, method: str, *args: Any) -> None: await client.send_packet( - PacketType.INVOCATION, - [ - {}, # header - None, # invocation_id - method, - list(args), - None, # streams - ], + InvocationPacket( + header={}, + invocation_id=None, + target=method, + arguments=list(args), + stream_ids=None, + ) ) return None diff --git a/app/signalr/packet.py b/app/signalr/packet.py index bb97afd..d3da080 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -1,7 +1,12 @@ from __future__ import annotations +from dataclasses import dataclass from enum import IntEnum -from typing import Any +import json +from typing import ( + Any, + Protocol as TypingProtocol, +) import msgpack @@ -18,43 +23,205 @@ class PacketType(IntEnum): CLOSE = 7 -class ResultKind(IntEnum): - ERROR = 1 - VOID = 2 - HAS_VALUE = 3 +@dataclass(kw_only=True) +class Packet: + type: PacketType + header: dict[str, Any] | None = None -def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]: - length, offset = decode_varint(data) - message_data = data[offset : offset + length] - # FIXME: custom deserializer for APIMod - # https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs - unpacked = msgpack.unpackb( - message_data, raw=False, strict_map_key=False, use_list=True - ) - return PacketType(unpacked[0]), unpacked[1:] +@dataclass(kw_only=True) +class InvocationPacket(Packet): + type: PacketType = PacketType.INVOCATION + invocation_id: str | None + target: str + arguments: list[Any] | None = None + stream_ids: list[str] | None = None -def encode_varint(value: int) -> bytes: - result = [] - while value >= 0x80: - result.append((value & 0x7F) | 0x80) - value >>= 7 - result.append(value & 0x7F) - return bytes(result) +@dataclass(kw_only=True) +class CompletionPacket(Packet): + type: PacketType = PacketType.COMPLETION + invocation_id: str + result: Any + error: str | None = None -def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: - result = 0 - shift = 0 - pos = offset +@dataclass(kw_only=True) +class PingPacket(Packet): + type: PacketType = PacketType.PING - while pos < len(data): - byte = data[pos] - result |= (byte & 0x7F) << shift - pos += 1 - if (byte & 0x80) == 0: - break - shift += 7 - return result, pos +PACKETS = { + PacketType.INVOCATION: InvocationPacket, + PacketType.COMPLETION: CompletionPacket, + PacketType.PING: PingPacket, +} + + +class Protocol(TypingProtocol): + @staticmethod + def decode(input: bytes) -> Packet: ... + + @staticmethod + def encode(packet: Packet) -> bytes: ... + + +class MsgpackProtocol: + @staticmethod + def _encode_varint(value: int) -> bytes: + result = [] + while value >= 0x80: + result.append((value & 0x7F) | 0x80) + value >>= 7 + result.append(value & 0x7F) + return bytes(result) + + @staticmethod + def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]: + result = 0 + shift = 0 + pos = offset + + while pos < len(data): + byte = data[pos] + result |= (byte & 0x7F) << shift + pos += 1 + if (byte & 0x80) == 0: + break + shift += 7 + + return result, pos + + @staticmethod + def decode(input: bytes) -> Packet: + length, offset = MsgpackProtocol._decode_varint(input) + message_data = input[offset : offset + length] + # FIXME: custom deserializer for APIMod + # https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs + unpacked = msgpack.unpackb( + message_data, raw=False, strict_map_key=False, use_list=True + ) + packet_type = PacketType(unpacked[0]) + if packet_type not in PACKETS: + raise ValueError(f"Unknown packet type: {packet_type}") + match packet_type: + case PacketType.INVOCATION: + return InvocationPacket( + header=unpacked[1], + invocation_id=unpacked[2], + target=unpacked[3], + arguments=unpacked[4] if len(unpacked) > 4 else None, + stream_ids=unpacked[5] if len(unpacked) > 5 else None, + ) + case PacketType.COMPLETION: + result_kind = unpacked[3] + return CompletionPacket( + header=unpacked[1], + invocation_id=unpacked[2], + error=unpacked[4] if result_kind == 1 else None, + result=unpacked[5] if result_kind == 3 else None, + ) + case PacketType.PING: + return PingPacket() + raise ValueError(f"Unsupported packet type: {packet_type}") + + @staticmethod + def encode(packet: Packet) -> bytes: + payload = [packet.type.value, packet.header or {}] + if isinstance(packet, InvocationPacket): + payload.extend( + [ + packet.invocation_id, + packet.target, + ] + ) + if packet.arguments is not None: + payload.append(packet.arguments) + if packet.stream_ids is not None: + payload.append(packet.stream_ids) + elif isinstance(packet, CompletionPacket): + result_kind = 2 + if packet.error: + result_kind = 1 + elif packet.result is None: + result_kind = 3 + payload.extend( + [ + packet.invocation_id, + result_kind, + packet.error or packet.result or None, + ] + ) + elif isinstance(packet, PingPacket): + pass + + data = msgpack.packb(payload, use_bin_type=True) + return MsgpackProtocol._encode_varint(len(data)) + data + + +class JSONProtocol: + @staticmethod + def decode(input: bytes) -> Packet: + data = json.loads(input[:-1].decode("utf-8")) + packet_type = PacketType(data["type"]) + if packet_type not in PACKETS: + raise ValueError(f"Unknown packet type: {packet_type}") + match packet_type: + case PacketType.INVOCATION: + return InvocationPacket( + header=data.get("header"), + invocation_id=data.get("invocationId"), + target=data["target"], + arguments=data.get("arguments"), + stream_ids=data.get("streamIds"), + ) + case PacketType.COMPLETION: + return CompletionPacket( + header=data.get("header"), + invocation_id=data["invocationId"], + error=data.get("error"), + result=data.get("result"), + ) + case PacketType.PING: + return PingPacket() + raise ValueError(f"Unsupported packet type: {packet_type}") + + @staticmethod + def encode(packet: Packet) -> bytes: + payload: dict[str, Any] = { + "type": packet.type.value, + } + if packet.header: + payload["header"] = packet.header + if isinstance(packet, InvocationPacket): + payload.update( + { + "target": packet.target, + } + ) + if packet.invocation_id is not None: + payload["invocationId"] = packet.invocation_id + if packet.arguments is not None: + payload["arguments"] = packet.arguments + if packet.stream_ids is not None: + payload["streamIds"] = packet.stream_ids + elif isinstance(packet, CompletionPacket): + payload.update( + { + "invocationId": packet.invocation_id, + } + ) + if packet.error is not None: + payload["error"] = packet.error + if packet.result is not None: + payload["result"] = packet.result + elif isinstance(packet, PingPacket): + pass + + return json.dumps(payload).encode("utf-8") + + +PROTOCOLS: dict[str, Protocol] = { + "json": JSONProtocol, + "messagepack": MsgpackProtocol, +} diff --git a/app/signalr/router.py b/app/signalr/router.py index 49934b7..3d70931 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -12,7 +12,7 @@ from app.dependencies.user import get_current_user_by_token from app.models.signalr import NegotiateResponse, Transport from .hub import Hubs -from .packet import SEP +from .packet import PROTOCOLS, SEP from fastapi import APIRouter, Depends, Header, Query, WebSocket from sqlmodel.ext.asyncio.session import AsyncSession @@ -62,13 +62,14 @@ async def connect( await websocket.accept() # handshake - handshake = await websocket.receive_bytes() - handshake_payload = json.loads(handshake[:-1]) + handshake = await websocket.receive() + message = handshake.get("bytes") or handshake.get("text") + if not message: + await websocket.close(code=1008) + return + handshake_payload = json.loads(message[:-1]) error = "" - if (protocol := handshake_payload.get("protocol")) != "messagepack" or ( - handshake_payload.get("version") - ) != 1: - error = f"Requested protocol '{protocol}' is not available." + protocol = handshake_payload.get("protocol", "json") client = None try: @@ -76,7 +77,10 @@ async def connect( connection_id=user_id, connection_token=id, connection=websocket, + protocol=PROTOCOLS[protocol], ) + except KeyError: + error = f"Protocol '{protocol}' is not supported." except TimeoutError: error = f"Connection {id} has waited too long." except ValueError as e: diff --git a/app/signalr/store.py b/app/signalr/store.py index 5258293..008da03 100644 --- a/app/signalr/store.py +++ b/app/signalr/store.py @@ -2,9 +2,7 @@ from __future__ import annotations import asyncio import sys -from typing import Any, Literal - -from .packet import ResultKind +from typing import Any class ResultStore: @@ -22,21 +20,17 @@ class ResultStore: return str(s) def add_result( - self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None + self, invocation_id: str, result: Any, error: str | None = None ) -> None: if isinstance(invocation_id, str) and invocation_id.isdecimal(): if future := self._futures.get(invocation_id): - future.set_result((type, result)) + future.set_result((result, error)) async def fetch( self, invocation_id: str, timeout: float | None, # noqa: ASYNC109 - ) -> ( - tuple[Literal[ResultKind.ERROR], str] - | tuple[Literal[ResultKind.VOID], None] - | tuple[Literal[ResultKind.HAS_VALUE], Any] - ): + ) -> tuple[Any, str | None]: future = asyncio.get_event_loop().create_future() self._futures[invocation_id] = future try: