feat(signalr): support json protocol
This commit is contained in:
@@ -8,41 +8,50 @@ from typing import Any
|
||||
from app.config import settings
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
PacketType,
|
||||
ResultKind,
|
||||
encode_varint,
|
||||
parse_packet,
|
||||
CompletionPacket,
|
||||
InvocationPacket,
|
||||
Packet,
|
||||
PingPacket,
|
||||
Protocol,
|
||||
)
|
||||
from app.signalr.store import ResultStore
|
||||
from app.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
import msgpack
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
connection: WebSocket,
|
||||
protocol: Protocol,
|
||||
) -> None:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self.procotol = protocol
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
|
||||
async def send_packet(self, type: PacketType, packet: list[Any]):
|
||||
packet.insert(0, type.value)
|
||||
payload = msgpack.packb(packet)
|
||||
length = encode_varint(len(payload))
|
||||
await self.connection.send_bytes(length + payload)
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
|
||||
async def receive_packet(self) -> Packet:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
raise WebSocketDisconnect(code=1008, reason="Empty message received.")
|
||||
return self.procotol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PacketType.PING, [])
|
||||
await self.send_packet(PingPacket())
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
@@ -61,7 +70,11 @@ class Hub:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
|
||||
def add_client(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
protocol: Protocol,
|
||||
connection: WebSocket,
|
||||
) -> Client:
|
||||
if connection_token in self.clients:
|
||||
raise ValueError(
|
||||
@@ -74,7 +87,7 @@ class Hub:
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
client = Client(connection_id, connection_token, connection)
|
||||
client = Client(connection_id, connection_token, connection, protocol)
|
||||
self.clients[connection_token] = client
|
||||
task = asyncio.create_task(client._ping())
|
||||
self.tasks.add(task)
|
||||
@@ -90,8 +103,8 @@ class Hub:
|
||||
client._ping_task.cancel()
|
||||
await client.connection.close()
|
||||
|
||||
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
|
||||
await client.send_packet(type, packet)
|
||||
async def send_packet(self, client: Client, packet: Packet) -> None:
|
||||
await client.send_packet(packet)
|
||||
|
||||
async def broadcast_call(self, method: str, *args: Any) -> None:
|
||||
tasks = []
|
||||
@@ -103,11 +116,8 @@ class Hub:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
message = await client.connection.receive_bytes()
|
||||
packet_type, packet_data = parse_packet(message)
|
||||
task = asyncio.create_task(
|
||||
self._handle_packet(client, packet_type, packet_data)
|
||||
)
|
||||
packet = await client.receive_packet()
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
@@ -121,51 +131,30 @@ class Hub:
|
||||
jump = True
|
||||
await self.remove_client(client.connection_id)
|
||||
|
||||
async def _handle_packet(
|
||||
self, client: Client, type: PacketType, packet: list[Any]
|
||||
) -> None:
|
||||
match type:
|
||||
case PacketType.PING:
|
||||
...
|
||||
case PacketType.INVOCATION:
|
||||
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
|
||||
target: str = packet[2]
|
||||
args: list[Any] | None = packet[3]
|
||||
if args is None:
|
||||
args = []
|
||||
# streams: list[str] | None = packet[4] # TODO: stream support
|
||||
code = ResultKind.VOID
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, target, args)
|
||||
if result is not None:
|
||||
code = ResultKind.HAS_VALUE
|
||||
except InvokeException as e:
|
||||
code = ResultKind.ERROR
|
||||
result = e.message
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
code = ResultKind.ERROR
|
||||
result = str(e)
|
||||
|
||||
packet = [
|
||||
{}, # header
|
||||
invocation_id,
|
||||
code.value,
|
||||
]
|
||||
if result is not None:
|
||||
packet.append(result)
|
||||
if invocation_id is not None:
|
||||
await client.send_packet(
|
||||
PacketType.COMPLETION,
|
||||
packet,
|
||||
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
||||
if isinstance(packet, PingPacket):
|
||||
return
|
||||
elif isinstance(packet, InvocationPacket):
|
||||
args = packet.arguments or []
|
||||
error = None
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, packet.target, args)
|
||||
except InvokeException as e:
|
||||
error = e.message
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
error = str(e)
|
||||
if packet.invocation_id is not None:
|
||||
await client.send_packet(
|
||||
CompletionPacket(
|
||||
invocation_id=packet.invocation_id,
|
||||
error=error,
|
||||
result=result,
|
||||
)
|
||||
case PacketType.COMPLETION:
|
||||
invocation_id: str = packet[1]
|
||||
code: ResultKind = ResultKind(packet[2])
|
||||
result: Any = packet[3] if len(packet) > 3 else None
|
||||
client._store.add_result(invocation_id, code, result)
|
||||
)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
client._store.add_result(packet.invocation_id, packet.result, packet.error)
|
||||
|
||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||
method_ = getattr(self, method, None)
|
||||
@@ -185,32 +174,28 @@ class Hub:
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
invocation_id = client._store.get_invocation_id()
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
invocation_id,
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
InvocationPacket(
|
||||
header={},
|
||||
invocation_id=invocation_id,
|
||||
target=method,
|
||||
arguments=list(args),
|
||||
stream_ids=None,
|
||||
)
|
||||
)
|
||||
r = await client._store.fetch(invocation_id, None)
|
||||
if r[0] == ResultKind.HAS_VALUE:
|
||||
return r[1]
|
||||
if r[0] == ResultKind.ERROR:
|
||||
if r[1]:
|
||||
raise InvokeException(r[1])
|
||||
return None
|
||||
return r[0]
|
||||
|
||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
None, # invocation_id
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
InvocationPacket(
|
||||
header={},
|
||||
invocation_id=None,
|
||||
target=method,
|
||||
arguments=list(args),
|
||||
stream_ids=None,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user