feat(signalr): support json protocol
This commit is contained in:
@@ -20,7 +20,7 @@ class MessagePackArrayModel(BaseModel):
|
|||||||
class Transport(BaseModel):
|
class Transport(BaseModel):
|
||||||
transport: str
|
transport: str
|
||||||
transfer_formats: list[str] = Field(
|
transfer_formats: list[str] = Field(
|
||||||
default_factory=lambda: ["Binary"], alias="transferFormats"
|
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,41 +8,50 @@ from typing import Any
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.signalr.exception import InvokeException
|
from app.signalr.exception import InvokeException
|
||||||
from app.signalr.packet import (
|
from app.signalr.packet import (
|
||||||
PacketType,
|
CompletionPacket,
|
||||||
ResultKind,
|
InvocationPacket,
|
||||||
encode_varint,
|
Packet,
|
||||||
parse_packet,
|
PingPacket,
|
||||||
|
Protocol,
|
||||||
)
|
)
|
||||||
from app.signalr.store import ResultStore
|
from app.signalr.store import ResultStore
|
||||||
from app.signalr.utils import get_signature
|
from app.signalr.utils import get_signature
|
||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
import msgpack
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.websockets import WebSocketDisconnect
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
self,
|
||||||
|
connection_id: str,
|
||||||
|
connection_token: str,
|
||||||
|
connection: WebSocket,
|
||||||
|
protocol: Protocol,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.connection_id = connection_id
|
self.connection_id = connection_id
|
||||||
self.connection_token = connection_token
|
self.connection_token = connection_token
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
self.procotol = protocol
|
||||||
self._listen_task: asyncio.Task | None = None
|
self._listen_task: asyncio.Task | None = None
|
||||||
self._ping_task: asyncio.Task | None = None
|
self._ping_task: asyncio.Task | None = None
|
||||||
self._store = ResultStore()
|
self._store = ResultStore()
|
||||||
|
|
||||||
async def send_packet(self, type: PacketType, packet: list[Any]):
|
async def send_packet(self, packet: Packet):
|
||||||
packet.insert(0, type.value)
|
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||||
payload = msgpack.packb(packet)
|
|
||||||
length = encode_varint(len(payload))
|
async def receive_packet(self) -> Packet:
|
||||||
await self.connection.send_bytes(length + payload)
|
message = await self.connection.receive()
|
||||||
|
d = message.get("bytes") or message.get("text", "").encode()
|
||||||
|
if not d:
|
||||||
|
raise WebSocketDisconnect(code=1008, reason="Empty message received.")
|
||||||
|
return self.procotol.decode(d)
|
||||||
|
|
||||||
async def _ping(self):
|
async def _ping(self):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self.send_packet(PacketType.PING, [])
|
await self.send_packet(PingPacket())
|
||||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
break
|
break
|
||||||
@@ -61,7 +70,11 @@ class Hub:
|
|||||||
self.waited_clients[connection_token] = timestamp
|
self.waited_clients[connection_token] = timestamp
|
||||||
|
|
||||||
def add_client(
|
def add_client(
|
||||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
self,
|
||||||
|
connection_id: str,
|
||||||
|
connection_token: str,
|
||||||
|
protocol: Protocol,
|
||||||
|
connection: WebSocket,
|
||||||
) -> Client:
|
) -> Client:
|
||||||
if connection_token in self.clients:
|
if connection_token in self.clients:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -74,7 +87,7 @@ class Hub:
|
|||||||
):
|
):
|
||||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||||
del self.waited_clients[connection_token]
|
del self.waited_clients[connection_token]
|
||||||
client = Client(connection_id, connection_token, connection)
|
client = Client(connection_id, connection_token, connection, protocol)
|
||||||
self.clients[connection_token] = client
|
self.clients[connection_token] = client
|
||||||
task = asyncio.create_task(client._ping())
|
task = asyncio.create_task(client._ping())
|
||||||
self.tasks.add(task)
|
self.tasks.add(task)
|
||||||
@@ -90,8 +103,8 @@ class Hub:
|
|||||||
client._ping_task.cancel()
|
client._ping_task.cancel()
|
||||||
await client.connection.close()
|
await client.connection.close()
|
||||||
|
|
||||||
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
|
async def send_packet(self, client: Client, packet: Packet) -> None:
|
||||||
await client.send_packet(type, packet)
|
await client.send_packet(packet)
|
||||||
|
|
||||||
async def broadcast_call(self, method: str, *args: Any) -> None:
|
async def broadcast_call(self, method: str, *args: Any) -> None:
|
||||||
tasks = []
|
tasks = []
|
||||||
@@ -103,11 +116,8 @@ class Hub:
|
|||||||
jump = False
|
jump = False
|
||||||
while not jump:
|
while not jump:
|
||||||
try:
|
try:
|
||||||
message = await client.connection.receive_bytes()
|
packet = await client.receive_packet()
|
||||||
packet_type, packet_data = parse_packet(message)
|
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||||
task = asyncio.create_task(
|
|
||||||
self._handle_packet(client, packet_type, packet_data)
|
|
||||||
)
|
|
||||||
self.tasks.add(task)
|
self.tasks.add(task)
|
||||||
task.add_done_callback(self.tasks.discard)
|
task.add_done_callback(self.tasks.discard)
|
||||||
except WebSocketDisconnect as e:
|
except WebSocketDisconnect as e:
|
||||||
@@ -121,51 +131,30 @@ class Hub:
|
|||||||
jump = True
|
jump = True
|
||||||
await self.remove_client(client.connection_id)
|
await self.remove_client(client.connection_id)
|
||||||
|
|
||||||
async def _handle_packet(
|
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
||||||
self, client: Client, type: PacketType, packet: list[Any]
|
if isinstance(packet, PingPacket):
|
||||||
) -> None:
|
return
|
||||||
match type:
|
elif isinstance(packet, InvocationPacket):
|
||||||
case PacketType.PING:
|
args = packet.arguments or []
|
||||||
...
|
error = None
|
||||||
case PacketType.INVOCATION:
|
result = None
|
||||||
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
|
try:
|
||||||
target: str = packet[2]
|
result = await self.invoke_method(client, packet.target, args)
|
||||||
args: list[Any] | None = packet[3]
|
except InvokeException as e:
|
||||||
if args is None:
|
error = e.message
|
||||||
args = []
|
except Exception as e:
|
||||||
# streams: list[str] | None = packet[4] # TODO: stream support
|
traceback.print_exc()
|
||||||
code = ResultKind.VOID
|
error = str(e)
|
||||||
result = None
|
if packet.invocation_id is not None:
|
||||||
try:
|
await client.send_packet(
|
||||||
result = await self.invoke_method(client, target, args)
|
CompletionPacket(
|
||||||
if result is not None:
|
invocation_id=packet.invocation_id,
|
||||||
code = ResultKind.HAS_VALUE
|
error=error,
|
||||||
except InvokeException as e:
|
result=result,
|
||||||
code = ResultKind.ERROR
|
|
||||||
result = e.message
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
code = ResultKind.ERROR
|
|
||||||
result = str(e)
|
|
||||||
|
|
||||||
packet = [
|
|
||||||
{}, # header
|
|
||||||
invocation_id,
|
|
||||||
code.value,
|
|
||||||
]
|
|
||||||
if result is not None:
|
|
||||||
packet.append(result)
|
|
||||||
if invocation_id is not None:
|
|
||||||
await client.send_packet(
|
|
||||||
PacketType.COMPLETION,
|
|
||||||
packet,
|
|
||||||
)
|
)
|
||||||
case PacketType.COMPLETION:
|
)
|
||||||
invocation_id: str = packet[1]
|
elif isinstance(packet, CompletionPacket):
|
||||||
code: ResultKind = ResultKind(packet[2])
|
client._store.add_result(packet.invocation_id, packet.result, packet.error)
|
||||||
result: Any = packet[3] if len(packet) > 3 else None
|
|
||||||
client._store.add_result(invocation_id, code, result)
|
|
||||||
|
|
||||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||||
method_ = getattr(self, method, None)
|
method_ = getattr(self, method, None)
|
||||||
@@ -185,32 +174,28 @@ class Hub:
|
|||||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||||
invocation_id = client._store.get_invocation_id()
|
invocation_id = client._store.get_invocation_id()
|
||||||
await client.send_packet(
|
await client.send_packet(
|
||||||
PacketType.INVOCATION,
|
InvocationPacket(
|
||||||
[
|
header={},
|
||||||
{}, # header
|
invocation_id=invocation_id,
|
||||||
invocation_id,
|
target=method,
|
||||||
method,
|
arguments=list(args),
|
||||||
list(args),
|
stream_ids=None,
|
||||||
None, # streams
|
)
|
||||||
],
|
|
||||||
)
|
)
|
||||||
r = await client._store.fetch(invocation_id, None)
|
r = await client._store.fetch(invocation_id, None)
|
||||||
if r[0] == ResultKind.HAS_VALUE:
|
if r[1]:
|
||||||
return r[1]
|
|
||||||
if r[0] == ResultKind.ERROR:
|
|
||||||
raise InvokeException(r[1])
|
raise InvokeException(r[1])
|
||||||
return None
|
return r[0]
|
||||||
|
|
||||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||||
await client.send_packet(
|
await client.send_packet(
|
||||||
PacketType.INVOCATION,
|
InvocationPacket(
|
||||||
[
|
header={},
|
||||||
{}, # header
|
invocation_id=None,
|
||||||
None, # invocation_id
|
target=method,
|
||||||
method,
|
arguments=list(args),
|
||||||
list(args),
|
stream_ids=None,
|
||||||
None, # streams
|
)
|
||||||
],
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any
|
import json
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Protocol as TypingProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
|
||||||
@@ -18,43 +23,205 @@ class PacketType(IntEnum):
|
|||||||
CLOSE = 7
|
CLOSE = 7
|
||||||
|
|
||||||
|
|
||||||
class ResultKind(IntEnum):
|
@dataclass(kw_only=True)
|
||||||
ERROR = 1
|
class Packet:
|
||||||
VOID = 2
|
type: PacketType
|
||||||
HAS_VALUE = 3
|
header: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
|
@dataclass(kw_only=True)
|
||||||
length, offset = decode_varint(data)
|
class InvocationPacket(Packet):
|
||||||
message_data = data[offset : offset + length]
|
type: PacketType = PacketType.INVOCATION
|
||||||
# FIXME: custom deserializer for APIMod
|
invocation_id: str | None
|
||||||
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
target: str
|
||||||
unpacked = msgpack.unpackb(
|
arguments: list[Any] | None = None
|
||||||
message_data, raw=False, strict_map_key=False, use_list=True
|
stream_ids: list[str] | None = None
|
||||||
)
|
|
||||||
return PacketType(unpacked[0]), unpacked[1:]
|
|
||||||
|
|
||||||
|
|
||||||
def encode_varint(value: int) -> bytes:
|
@dataclass(kw_only=True)
|
||||||
result = []
|
class CompletionPacket(Packet):
|
||||||
while value >= 0x80:
|
type: PacketType = PacketType.COMPLETION
|
||||||
result.append((value & 0x7F) | 0x80)
|
invocation_id: str
|
||||||
value >>= 7
|
result: Any
|
||||||
result.append(value & 0x7F)
|
error: str | None = None
|
||||||
return bytes(result)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
@dataclass(kw_only=True)
|
||||||
result = 0
|
class PingPacket(Packet):
|
||||||
shift = 0
|
type: PacketType = PacketType.PING
|
||||||
pos = offset
|
|
||||||
|
|
||||||
while pos < len(data):
|
|
||||||
byte = data[pos]
|
|
||||||
result |= (byte & 0x7F) << shift
|
|
||||||
pos += 1
|
|
||||||
if (byte & 0x80) == 0:
|
|
||||||
break
|
|
||||||
shift += 7
|
|
||||||
|
|
||||||
return result, pos
|
PACKETS = {
|
||||||
|
PacketType.INVOCATION: InvocationPacket,
|
||||||
|
PacketType.COMPLETION: CompletionPacket,
|
||||||
|
PacketType.PING: PingPacket,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Protocol(TypingProtocol):
|
||||||
|
@staticmethod
|
||||||
|
def decode(input: bytes) -> Packet: ...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode(packet: Packet) -> bytes: ...
|
||||||
|
|
||||||
|
|
||||||
|
class MsgpackProtocol:
|
||||||
|
@staticmethod
|
||||||
|
def _encode_varint(value: int) -> bytes:
|
||||||
|
result = []
|
||||||
|
while value >= 0x80:
|
||||||
|
result.append((value & 0x7F) | 0x80)
|
||||||
|
value >>= 7
|
||||||
|
result.append(value & 0x7F)
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
||||||
|
result = 0
|
||||||
|
shift = 0
|
||||||
|
pos = offset
|
||||||
|
|
||||||
|
while pos < len(data):
|
||||||
|
byte = data[pos]
|
||||||
|
result |= (byte & 0x7F) << shift
|
||||||
|
pos += 1
|
||||||
|
if (byte & 0x80) == 0:
|
||||||
|
break
|
||||||
|
shift += 7
|
||||||
|
|
||||||
|
return result, pos
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def decode(input: bytes) -> Packet:
|
||||||
|
length, offset = MsgpackProtocol._decode_varint(input)
|
||||||
|
message_data = input[offset : offset + length]
|
||||||
|
# FIXME: custom deserializer for APIMod
|
||||||
|
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
||||||
|
unpacked = msgpack.unpackb(
|
||||||
|
message_data, raw=False, strict_map_key=False, use_list=True
|
||||||
|
)
|
||||||
|
packet_type = PacketType(unpacked[0])
|
||||||
|
if packet_type not in PACKETS:
|
||||||
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||||
|
match packet_type:
|
||||||
|
case PacketType.INVOCATION:
|
||||||
|
return InvocationPacket(
|
||||||
|
header=unpacked[1],
|
||||||
|
invocation_id=unpacked[2],
|
||||||
|
target=unpacked[3],
|
||||||
|
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||||
|
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||||
|
)
|
||||||
|
case PacketType.COMPLETION:
|
||||||
|
result_kind = unpacked[3]
|
||||||
|
return CompletionPacket(
|
||||||
|
header=unpacked[1],
|
||||||
|
invocation_id=unpacked[2],
|
||||||
|
error=unpacked[4] if result_kind == 1 else None,
|
||||||
|
result=unpacked[5] if result_kind == 3 else None,
|
||||||
|
)
|
||||||
|
case PacketType.PING:
|
||||||
|
return PingPacket()
|
||||||
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode(packet: Packet) -> bytes:
|
||||||
|
payload = [packet.type.value, packet.header or {}]
|
||||||
|
if isinstance(packet, InvocationPacket):
|
||||||
|
payload.extend(
|
||||||
|
[
|
||||||
|
packet.invocation_id,
|
||||||
|
packet.target,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if packet.arguments is not None:
|
||||||
|
payload.append(packet.arguments)
|
||||||
|
if packet.stream_ids is not None:
|
||||||
|
payload.append(packet.stream_ids)
|
||||||
|
elif isinstance(packet, CompletionPacket):
|
||||||
|
result_kind = 2
|
||||||
|
if packet.error:
|
||||||
|
result_kind = 1
|
||||||
|
elif packet.result is None:
|
||||||
|
result_kind = 3
|
||||||
|
payload.extend(
|
||||||
|
[
|
||||||
|
packet.invocation_id,
|
||||||
|
result_kind,
|
||||||
|
packet.error or packet.result or None,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif isinstance(packet, PingPacket):
|
||||||
|
pass
|
||||||
|
|
||||||
|
data = msgpack.packb(payload, use_bin_type=True)
|
||||||
|
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||||
|
|
||||||
|
|
||||||
|
class JSONProtocol:
|
||||||
|
@staticmethod
|
||||||
|
def decode(input: bytes) -> Packet:
|
||||||
|
data = json.loads(input[:-1].decode("utf-8"))
|
||||||
|
packet_type = PacketType(data["type"])
|
||||||
|
if packet_type not in PACKETS:
|
||||||
|
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||||
|
match packet_type:
|
||||||
|
case PacketType.INVOCATION:
|
||||||
|
return InvocationPacket(
|
||||||
|
header=data.get("header"),
|
||||||
|
invocation_id=data.get("invocationId"),
|
||||||
|
target=data["target"],
|
||||||
|
arguments=data.get("arguments"),
|
||||||
|
stream_ids=data.get("streamIds"),
|
||||||
|
)
|
||||||
|
case PacketType.COMPLETION:
|
||||||
|
return CompletionPacket(
|
||||||
|
header=data.get("header"),
|
||||||
|
invocation_id=data["invocationId"],
|
||||||
|
error=data.get("error"),
|
||||||
|
result=data.get("result"),
|
||||||
|
)
|
||||||
|
case PacketType.PING:
|
||||||
|
return PingPacket()
|
||||||
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode(packet: Packet) -> bytes:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": packet.type.value,
|
||||||
|
}
|
||||||
|
if packet.header:
|
||||||
|
payload["header"] = packet.header
|
||||||
|
if isinstance(packet, InvocationPacket):
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"target": packet.target,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if packet.invocation_id is not None:
|
||||||
|
payload["invocationId"] = packet.invocation_id
|
||||||
|
if packet.arguments is not None:
|
||||||
|
payload["arguments"] = packet.arguments
|
||||||
|
if packet.stream_ids is not None:
|
||||||
|
payload["streamIds"] = packet.stream_ids
|
||||||
|
elif isinstance(packet, CompletionPacket):
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"invocationId": packet.invocation_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if packet.error is not None:
|
||||||
|
payload["error"] = packet.error
|
||||||
|
if packet.result is not None:
|
||||||
|
payload["result"] = packet.result
|
||||||
|
elif isinstance(packet, PingPacket):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return json.dumps(payload).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
PROTOCOLS: dict[str, Protocol] = {
|
||||||
|
"json": JSONProtocol,
|
||||||
|
"messagepack": MsgpackProtocol,
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.dependencies.user import get_current_user_by_token
|
|||||||
from app.models.signalr import NegotiateResponse, Transport
|
from app.models.signalr import NegotiateResponse, Transport
|
||||||
|
|
||||||
from .hub import Hubs
|
from .hub import Hubs
|
||||||
from .packet import SEP
|
from .packet import PROTOCOLS, SEP
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -62,13 +62,14 @@ async def connect(
|
|||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
# handshake
|
# handshake
|
||||||
handshake = await websocket.receive_bytes()
|
handshake = await websocket.receive()
|
||||||
handshake_payload = json.loads(handshake[:-1])
|
message = handshake.get("bytes") or handshake.get("text")
|
||||||
|
if not message:
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
handshake_payload = json.loads(message[:-1])
|
||||||
error = ""
|
error = ""
|
||||||
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
|
protocol = handshake_payload.get("protocol", "json")
|
||||||
handshake_payload.get("version")
|
|
||||||
) != 1:
|
|
||||||
error = f"Requested protocol '{protocol}' is not available."
|
|
||||||
|
|
||||||
client = None
|
client = None
|
||||||
try:
|
try:
|
||||||
@@ -76,7 +77,10 @@ async def connect(
|
|||||||
connection_id=user_id,
|
connection_id=user_id,
|
||||||
connection_token=id,
|
connection_token=id,
|
||||||
connection=websocket,
|
connection=websocket,
|
||||||
|
protocol=PROTOCOLS[protocol],
|
||||||
)
|
)
|
||||||
|
except KeyError:
|
||||||
|
error = f"Protocol '{protocol}' is not supported."
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
error = f"Connection {id} has waited too long."
|
error = f"Connection {id} has waited too long."
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Literal
|
from typing import Any
|
||||||
|
|
||||||
from .packet import ResultKind
|
|
||||||
|
|
||||||
|
|
||||||
class ResultStore:
|
class ResultStore:
|
||||||
@@ -22,21 +20,17 @@ class ResultStore:
|
|||||||
return str(s)
|
return str(s)
|
||||||
|
|
||||||
def add_result(
|
def add_result(
|
||||||
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
|
self, invocation_id: str, result: Any, error: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(invocation_id, str) and invocation_id.isdecimal():
|
if isinstance(invocation_id, str) and invocation_id.isdecimal():
|
||||||
if future := self._futures.get(invocation_id):
|
if future := self._futures.get(invocation_id):
|
||||||
future.set_result((type, result))
|
future.set_result((result, error))
|
||||||
|
|
||||||
async def fetch(
|
async def fetch(
|
||||||
self,
|
self,
|
||||||
invocation_id: str,
|
invocation_id: str,
|
||||||
timeout: float | None, # noqa: ASYNC109
|
timeout: float | None, # noqa: ASYNC109
|
||||||
) -> (
|
) -> tuple[Any, str | None]:
|
||||||
tuple[Literal[ResultKind.ERROR], str]
|
|
||||||
| tuple[Literal[ResultKind.VOID], None]
|
|
||||||
| tuple[Literal[ResultKind.HAS_VALUE], Any]
|
|
||||||
):
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
self._futures[invocation_id] = future
|
self._futures[invocation_id] = future
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user