212 lines
7.6 KiB
Python
212 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
from typing import Any
|
|
|
|
from app.config import settings
|
|
from app.router.signalr.exception import InvokeException
|
|
from app.router.signalr.packet import (
|
|
PacketType,
|
|
ResultKind,
|
|
encode_varint,
|
|
parse_packet,
|
|
)
|
|
from app.router.signalr.store import ResultStore
|
|
from app.router.signalr.utils import get_signature
|
|
|
|
from fastapi import WebSocket
|
|
import msgpack
|
|
from pydantic import BaseModel
|
|
from starlette.websockets import WebSocketDisconnect
|
|
|
|
|
|
class Client:
    """Server-side state for one connected SignalR client.

    Messages are exchanged as MessagePack arrays, each prefixed with its
    varint-encoded byte length (see ``send_packet``).
    """

    def __init__(
        self, connection_id: str, connection_token: str, connection: WebSocket
    ) -> None:
        self.connection_id = connection_id
        self.connection_token = connection_token
        self.connection = connection
        # Background tasks belonging to this client. Hub.add_client sets the
        # ping task; the listen task is presumably assigned by whoever starts
        # Hub._listen_client (not visible here) -- confirm.
        self._listen_task: asyncio.Task | None = None
        self._ping_task: asyncio.Task | None = None
        # Results of server-to-client invocations: filled from COMPLETION
        # packets and awaited by Hub.call.
        self._store = ResultStore()

    async def send_packet(self, type: PacketType, packet: list[Any]):
        """Serialize one message and send it to the client as a binary frame.

        The wire format is ``varint(len(payload)) + payload``, where ``payload``
        is the MessagePack encoding of ``[type.value, *packet]``.

        Args:
            type: Packet type; its ``value`` becomes the first array element.
            packet: Remaining message fields. The list is NOT modified, so
                callers may reuse it.
        """
        # Build a fresh list instead of inserting into ``packet`` so the
        # caller's list is left untouched.
        payload = msgpack.packb([type.value, *packet])
        length = encode_varint(len(payload))
        await self.connection.send_bytes(length + payload)

    async def _ping(self):
        """Send a PING packet every ``settings.SIGNALR_PING_INTERVAL`` seconds.

        The loop ends quietly when the socket disconnects. It also ends, after
        printing the error, when sending fails for any other reason. Task
        cancellation is not caught here (``CancelledError`` is not an
        ``Exception``), so ``Task.cancel()`` stops the loop.
        """
        while True:
            try:
                await self.send_packet(PacketType.PING, [])
                await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
            except WebSocketDisconnect:
                break
            except Exception as e:
                print(f"Error in ping task for {self.connection_id}: {e}")
                break
|
|
|
|
|
|
class Hub:
|
|
def __init__(self) -> None:
|
|
self.clients: dict[str, Client] = {}
|
|
self.waited_clients: dict[str, int] = {}
|
|
self.tasks: set[asyncio.Task] = set()
|
|
|
|
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
|
self.waited_clients[connection_token] = timestamp
|
|
|
|
def add_client(
|
|
self, connection_id: str, connection_token: str, connection: WebSocket
|
|
) -> Client:
|
|
if connection_token in self.clients:
|
|
raise ValueError(
|
|
f"Client with connection token {connection_token} already exists."
|
|
)
|
|
if connection_token in self.waited_clients:
|
|
if (
|
|
self.waited_clients[connection_token]
|
|
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
|
):
|
|
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
|
del self.waited_clients[connection_token]
|
|
client = Client(connection_id, connection_token, connection)
|
|
self.clients[connection_token] = client
|
|
task = asyncio.create_task(client._ping())
|
|
self.tasks.add(task)
|
|
client._ping_task = task
|
|
return client
|
|
|
|
async def remove_client(self, connection_id: str) -> None:
|
|
if client := self.clients.get(connection_id):
|
|
del self.clients[connection_id]
|
|
if client._listen_task:
|
|
client._listen_task.cancel()
|
|
if client._ping_task:
|
|
client._ping_task.cancel()
|
|
await client.connection.close()
|
|
|
|
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
|
|
await client.send_packet(type, packet)
|
|
|
|
async def _listen_client(self, client: Client) -> None:
|
|
jump = False
|
|
while not jump:
|
|
try:
|
|
message = await client.connection.receive_bytes()
|
|
packet_type, packet_data = parse_packet(message)
|
|
task = asyncio.create_task(
|
|
self._handle_packet(client, packet_type, packet_data)
|
|
)
|
|
self.tasks.add(task)
|
|
task.add_done_callback(self.tasks.discard)
|
|
except WebSocketDisconnect as e:
|
|
if e.code == 1005:
|
|
continue
|
|
print(
|
|
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
|
)
|
|
jump = True
|
|
except Exception as e:
|
|
print(f"Error in client {client.connection_id}: {e}")
|
|
jump = True
|
|
await self.remove_client(client.connection_id)
|
|
|
|
async def _handle_packet(
|
|
self, client: Client, type: PacketType, packet: list[Any]
|
|
) -> None:
|
|
match type:
|
|
case PacketType.PING:
|
|
...
|
|
case PacketType.INVOCATION:
|
|
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
|
|
target: str = packet[2]
|
|
args: list[Any] | None = packet[3]
|
|
if args is None:
|
|
args = []
|
|
# streams: list[str] | None = packet[4] # TODO: stream support
|
|
code = ResultKind.VOID
|
|
result = None
|
|
try:
|
|
result = await self.invoke_method(client, target, args)
|
|
if result is not None:
|
|
code = ResultKind.HAS_VALUE
|
|
except InvokeException as e:
|
|
code = ResultKind.ERROR
|
|
result = e.message
|
|
|
|
except Exception as e:
|
|
code = ResultKind.ERROR
|
|
result = str(e)
|
|
|
|
packet = [
|
|
{}, # header
|
|
invocation_id,
|
|
code.value,
|
|
]
|
|
if result is not None:
|
|
packet.append(result)
|
|
if invocation_id is not None:
|
|
await client.send_packet(
|
|
PacketType.COMPLETION,
|
|
packet,
|
|
)
|
|
case PacketType.COMPLETION:
|
|
invocation_id: str = packet[1]
|
|
code: ResultKind = ResultKind(packet[2])
|
|
result: Any = packet[3] if len(packet) > 3 else None
|
|
client._store.add_result(invocation_id, code, result)
|
|
|
|
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
|
method_ = getattr(self, method, None)
|
|
call_params = []
|
|
if not method_:
|
|
raise InvokeException(f"Method '{method}' not found in hub.")
|
|
signature = get_signature(method_)
|
|
for name, param in signature.parameters.items():
|
|
if name == "self" or param.annotation is Client:
|
|
continue
|
|
if issubclass(param.annotation, BaseModel):
|
|
call_params.append(param.annotation.model_validate(args.pop(0)))
|
|
else:
|
|
call_params.append(args.pop(0))
|
|
return await method_(client, *call_params)
|
|
|
|
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
|
invocation_id = client._store.get_invocation_id()
|
|
await client.send_packet(
|
|
PacketType.INVOCATION,
|
|
[
|
|
{}, # header
|
|
invocation_id,
|
|
method,
|
|
list(args),
|
|
None, # streams
|
|
],
|
|
)
|
|
r = await client._store.fetch(invocation_id, None)
|
|
if r[0] == ResultKind.HAS_VALUE:
|
|
return r[1]
|
|
if r[0] == ResultKind.ERROR:
|
|
raise InvokeException(r[1])
|
|
return None
|
|
|
|
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
|
await client.send_packet(
|
|
PacketType.INVOCATION,
|
|
[
|
|
{}, # header
|
|
None, # invocation_id
|
|
method,
|
|
list(args),
|
|
None, # streams
|
|
],
|
|
)
|
|
return None
|
|
|
|
def __contains__(self, item: str) -> bool:
|
|
return item in self.clients or item in self.waited_clients
|