feat(signalr): graceful state manager

This commit is contained in:
MingxuanGame
2025-07-28 08:46:20 +00:00
parent 722a6e57d8
commit f60283a6c2
9 changed files with 234 additions and 109 deletions

View File

@@ -1,13 +1,16 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
import time
import traceback
from typing import Any
from app.config import settings
from app.models.signalr import UserState
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
InvocationPacket,
Packet,
@@ -22,6 +25,19 @@ from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class CloseConnection(Exception):
def __init__(
self,
message: str = "Connection closed",
allow_reconnect: bool = False,
from_client: bool = False,
) -> None:
super().__init__(message)
self.message = message
self.allow_reconnect = allow_reconnect
self.from_client = from_client
class Client:
def __init__(
self,
@@ -39,7 +55,11 @@ class Client:
self._store = ResultStore()
def __hash__(self) -> int:
return hash(self.connection_id + self.connection_token)
return hash(self.connection_token)
@property
def user_id(self) -> int:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
@@ -48,7 +68,7 @@ class Client:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return [PingPacket()] # FIXME: Graceful empty message handling
return []
return self.procotol.decode(d)
async def _ping(self):
@@ -63,12 +83,13 @@ class Client:
break
class Hub:
class Hub[TState: UserState]:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
self.groups: dict[str, set[Client]] = {}
self.state: dict[int, TState] = {}
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
@@ -79,7 +100,25 @@ class Hub:
return client
return default
def add_client(
@abstractmethod
def create_state(self, client: Client) -> TState:
raise NotImplementedError
def get_or_create_state(self, client: Client) -> TState:
if (state := self.state.get(client.user_id)) is not None:
return state
state = self.create_state(client)
self.state[client.user_id] = state
return state
def add_to_group(self, client: Client, group_id: str) -> None:
self.groups.setdefault(group_id, set()).add(client)
def remove_from_group(self, client: Client, group_id: str) -> None:
if group_id in self.groups:
self.groups[group_id].discard(client)
async def add_client(
self,
connection_id: str,
connection_token: str,
@@ -104,19 +143,34 @@ class Hub:
client._ping_task = task
return client
async def remove_client(self, client: Client) -> None:
del self.clients[client.connection_token]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
for group in self.groups.values():
group.discard(client)
await self.clean_state(client, False)
@abstractmethod
async def _clean_state(self, state: TState) -> None:
return
async def clean_state(self, client: Client, disconnected: bool) -> None:
if (state := self.state.get(client.user_id)) is None:
return
if disconnected and client.connection_token != state.connection_token:
return
try:
await self._clean_state(state)
except Exception:
...
async def on_connect(self, client: Client) -> None:
if method := getattr(self, "on_client_connect", None):
await method(client)
async def remove_client(self, connection_id: str) -> None:
if client := self.clients.get(connection_id):
del self.clients[connection_id]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
await client.connection.close()
async def send_packet(self, client: Client, packet: Packet) -> None:
await client.send_packet(packet)
@@ -135,26 +189,40 @@ class Hub:
await asyncio.gather(*tasks)
async def _listen_client(self, client: Client) -> None:
jump = False
while not jump:
try:
try:
while True:
packets = await client.receive_packets()
for packet in packets:
if isinstance(packet, PingPacket):
continue
elif isinstance(packet, ClosePacket):
raise CloseConnection(
packet.error or "Connection closed by client",
packet.allow_reconnect,
True,
)
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
)
jump = True
except Exception as e:
except WebSocketDisconnect as e:
print(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
print(f"Client {client.connection_id} closed the connection.")
else:
traceback.print_exc()
print(f"Error in client {client.connection_id}: {e}")
jump = True
await self.remove_client(client.connection_id)
print(f"RuntimeError in client {client.connection_id}: {e}")
except CloseConnection as e:
if not e.from_client:
await client.send_packet(
ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)
)
print(f"Client {client.connection_id} closed the connection: {e.message}")
except Exception as e:
traceback.print_exc()
print(f"Error in client {client.connection_id}: {e}")
await self.remove_client(client)
async def _handle_packet(self, client: Client, packet: Packet) -> None:
if isinstance(packet, PingPacket):