feat(signalr): graceful state manager
This commit is contained in:
@@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.models.signalr import UserState
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
@@ -126,7 +128,7 @@ UserActivity = (
|
||||
)
|
||||
|
||||
|
||||
class MetadataClientState(BaseModel):
|
||||
class MetadataClientState(UserState):
|
||||
user_activity: UserActivity | None = None
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
|
||||
@@ -64,3 +64,8 @@ class NegotiateResponse(BaseModel):
|
||||
connectionToken: str
|
||||
negotiateVersion: int = 1
|
||||
availableTransports: list[Transport]
|
||||
|
||||
|
||||
class UserState(BaseModel):
|
||||
connection_id: str
|
||||
connection_token: str
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.models.beatmap import BeatmapRankStatus
|
||||
from .score import (
|
||||
ScoreStatisticsInt,
|
||||
)
|
||||
from .signalr import MessagePackArrayModel
|
||||
from .signalr import MessagePackArrayModel, UserState
|
||||
|
||||
import msgpack
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -128,7 +128,7 @@ class StoreScore(BaseModel):
|
||||
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StoreClientState(BaseModel):
|
||||
class StoreClientState(UserState):
|
||||
state: SpectatorState | None = None
|
||||
beatmap_status: BeatmapRankStatus | None = None
|
||||
checksum: str | None = None
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
InvocationPacket,
|
||||
Packet,
|
||||
@@ -22,6 +25,19 @@ from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class CloseConnection(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Connection closed",
|
||||
allow_reconnect: bool = False,
|
||||
from_client: bool = False,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.allow_reconnect = allow_reconnect
|
||||
self.from_client = from_client
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -39,7 +55,11 @@ class Client:
|
||||
self._store = ResultStore()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.connection_id + self.connection_token)
|
||||
return hash(self.connection_token)
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
return int(self.connection_id)
|
||||
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
@@ -48,7 +68,7 @@ class Client:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return [PingPacket()] # FIXME: Graceful empty message handling
|
||||
return []
|
||||
return self.procotol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
@@ -63,12 +83,13 @@ class Client:
|
||||
break
|
||||
|
||||
|
||||
class Hub:
|
||||
class Hub[TState: UserState]:
|
||||
def __init__(self) -> None:
|
||||
self.clients: dict[str, Client] = {}
|
||||
self.waited_clients: dict[str, int] = {}
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.groups: dict[str, set[Client]] = {}
|
||||
self.state: dict[int, TState] = {}
|
||||
|
||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
@@ -79,7 +100,25 @@ class Hub:
|
||||
return client
|
||||
return default
|
||||
|
||||
def add_client(
|
||||
@abstractmethod
|
||||
def create_state(self, client: Client) -> TState:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_or_create_state(self, client: Client) -> TState:
|
||||
if (state := self.state.get(client.user_id)) is not None:
|
||||
return state
|
||||
state = self.create_state(client)
|
||||
self.state[client.user_id] = state
|
||||
return state
|
||||
|
||||
def add_to_group(self, client: Client, group_id: str) -> None:
|
||||
self.groups.setdefault(group_id, set()).add(client)
|
||||
|
||||
def remove_from_group(self, client: Client, group_id: str) -> None:
|
||||
if group_id in self.groups:
|
||||
self.groups[group_id].discard(client)
|
||||
|
||||
async def add_client(
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
@@ -104,19 +143,34 @@ class Hub:
|
||||
client._ping_task = task
|
||||
return client
|
||||
|
||||
async def remove_client(self, client: Client) -> None:
|
||||
del self.clients[client.connection_token]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
for group in self.groups.values():
|
||||
group.discard(client)
|
||||
await self.clean_state(client, False)
|
||||
|
||||
@abstractmethod
|
||||
async def _clean_state(self, state: TState) -> None:
|
||||
return
|
||||
|
||||
async def clean_state(self, client: Client, disconnected: bool) -> None:
|
||||
if (state := self.state.get(client.user_id)) is None:
|
||||
return
|
||||
if disconnected and client.connection_token != state.connection_token:
|
||||
return
|
||||
try:
|
||||
await self._clean_state(state)
|
||||
except Exception:
|
||||
...
|
||||
|
||||
async def on_connect(self, client: Client) -> None:
|
||||
if method := getattr(self, "on_client_connect", None):
|
||||
await method(client)
|
||||
|
||||
async def remove_client(self, connection_id: str) -> None:
|
||||
if client := self.clients.get(connection_id):
|
||||
del self.clients[connection_id]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
await client.connection.close()
|
||||
|
||||
async def send_packet(self, client: Client, packet: Packet) -> None:
|
||||
await client.send_packet(packet)
|
||||
|
||||
@@ -135,26 +189,40 @@ class Hub:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _listen_client(self, client: Client) -> None:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
try:
|
||||
while True:
|
||||
packets = await client.receive_packets()
|
||||
for packet in packets:
|
||||
if isinstance(packet, PingPacket):
|
||||
continue
|
||||
elif isinstance(packet, ClosePacket):
|
||||
raise CloseConnection(
|
||||
packet.error or "Connection closed by client",
|
||||
packet.allow_reconnect,
|
||||
True,
|
||||
)
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
print(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
jump = True
|
||||
except Exception as e:
|
||||
except WebSocketDisconnect as e:
|
||||
print(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
print(f"Client {client.connection_id} closed the connection.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
print(f"Error in client {client.connection_id}: {e}")
|
||||
jump = True
|
||||
await self.remove_client(client.connection_id)
|
||||
print(f"RuntimeError in client {client.connection_id}: {e}")
|
||||
except CloseConnection as e:
|
||||
if not e.from_client:
|
||||
await client.send_packet(
|
||||
ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)
|
||||
)
|
||||
print(f"Client {client.connection_id} closed the connection: {e.message}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Error in client {client.connection_id}: {e}")
|
||||
|
||||
await self.remove_client(client)
|
||||
|
||||
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
||||
if isinstance(packet, PingPacket):
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from typing import override
|
||||
|
||||
from app.database.relationship import Relationship, RelationshipType
|
||||
from app.dependencies.database import engine
|
||||
@@ -16,32 +17,32 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
|
||||
|
||||
class MetadataHub(Hub):
|
||||
class MetadataHub(Hub[MetadataClientState]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.state: dict[int, MetadataClientState] = {}
|
||||
|
||||
@staticmethod
|
||||
def online_presence_watchers_group() -> str:
|
||||
return ONLINE_PRESENCE_WATCHERS_GROUP
|
||||
|
||||
def broadcast_tasks(
|
||||
self, user_id: int, store: MetadataClientState
|
||||
self, user_id: int, store: MetadataClientState | None
|
||||
) -> set[Coroutine]:
|
||||
if not store.pushable:
|
||||
if store is not None and not store.pushable:
|
||||
return set()
|
||||
data = store.to_dict() if store else None
|
||||
return {
|
||||
self.broadcast_group_call(
|
||||
self.online_presence_watchers_group(),
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
data,
|
||||
),
|
||||
self.broadcast_group_call(
|
||||
self.friend_presence_watchers_group(user_id),
|
||||
"FriendPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
data,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -49,11 +50,21 @@ class MetadataHub(Hub):
|
||||
def friend_presence_watchers_group(user_id: int):
|
||||
return f"metadata:friend-presence-watchers:{user_id}"
|
||||
|
||||
@override
|
||||
async def _clean_state(self, state: MetadataClientState) -> None:
|
||||
if state.pushable:
|
||||
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> MetadataClientState:
|
||||
return MetadataClientState(
|
||||
connection_id=client.connection_id,
|
||||
connection_token=client.connection_token,
|
||||
)
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
if store := self.state.get(user_id):
|
||||
store = MetadataClientState()
|
||||
self.state[user_id] = store
|
||||
self.get_or_create_state(client)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
@@ -73,6 +84,7 @@ class MetadataHub(Hub):
|
||||
if (
|
||||
friend_state := self.state.get(friend_id)
|
||||
) and friend_state.pushable:
|
||||
print("Pushed")
|
||||
tasks.append(
|
||||
self.broadcast_group_call(
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
@@ -86,14 +98,10 @@ class MetadataHub(Hub):
|
||||
async def UpdateStatus(self, client: Client, status: int) -> None:
|
||||
status_ = OnlineStatus(status)
|
||||
user_id = int(client.connection_id)
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
if store.status is not None and store.status == status_:
|
||||
return
|
||||
store.status = OnlineStatus(status_)
|
||||
else:
|
||||
store = MetadataClientState(status=OnlineStatus(status_))
|
||||
self.state[user_id] = store
|
||||
store = self.get_or_create_state(client)
|
||||
if store.status is not None and store.status == status_:
|
||||
return
|
||||
store.status = OnlineStatus(status_)
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
@@ -112,14 +120,8 @@ class MetadataHub(Hub):
|
||||
if activity_dict
|
||||
else None
|
||||
)
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.user_activity = activity
|
||||
else:
|
||||
store = MetadataClientState(
|
||||
user_activity=activity,
|
||||
)
|
||||
self.state[user_id] = store
|
||||
store = self.get_or_create_state(client)
|
||||
store.user_activity = activity
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
@@ -144,9 +146,7 @@ class MetadataHub(Hub):
|
||||
if store.pushable
|
||||
]
|
||||
)
|
||||
self.groups.setdefault(self.online_presence_watchers_group(), set()).add(client)
|
||||
self.add_to_group(client, self.online_presence_watchers_group())
|
||||
|
||||
async def EndWatchingUserPresence(self, client: Client) -> None:
|
||||
self.groups.setdefault(self.online_presence_watchers_group(), set()).discard(
|
||||
client
|
||||
)
|
||||
self.remove_from_group(client, self.online_presence_watchers_group())
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import lzma
|
||||
import struct
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from app.database import Beatmap
|
||||
from app.database.score import Score
|
||||
@@ -140,15 +141,29 @@ def save_replay(
|
||||
replay_path.write_bytes(data)
|
||||
|
||||
|
||||
class SpectatorHub(Hub):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.state: dict[int, StoreClientState] = {}
|
||||
|
||||
class SpectatorHub(Hub[StoreClientState]):
|
||||
@staticmethod
|
||||
def group_id(user_id: int) -> str:
|
||||
return f"watch:{user_id}"
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> StoreClientState:
|
||||
return StoreClientState(
|
||||
connection_id=client.connection_id,
|
||||
connection_token=client.connection_token,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _clean_state(self, state: StoreClientState) -> None:
|
||||
if state.state:
|
||||
await self._end_session(int(state.connection_id), state.state)
|
||||
for target in self.waited_clients:
|
||||
target_client = self.get_client_by_id(target)
|
||||
if target_client:
|
||||
await self.call_noblock(
|
||||
target_client, "UserEndedWatching", int(state.connection_id)
|
||||
)
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
tasks = [
|
||||
self.call_noblock(
|
||||
@@ -163,8 +178,8 @@ class SpectatorHub(Hub):
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
previous_state = self.state.get(user_id)
|
||||
if previous_state is not None:
|
||||
store = self.get_or_create_state(client)
|
||||
if store.state is not None:
|
||||
return
|
||||
if state.beatmap_id is None or state.ruleset_id is None:
|
||||
return
|
||||
@@ -183,23 +198,19 @@ class SpectatorHub(Hub):
|
||||
if not user:
|
||||
return
|
||||
name = user.name
|
||||
store = StoreClientState(
|
||||
state=state,
|
||||
beatmap_status=beatmap.beatmap_status,
|
||||
checksum=beatmap.checksum,
|
||||
ruleset_id=state.ruleset_id,
|
||||
score_token=score_token,
|
||||
watched_user=set(),
|
||||
score=StoreScore(
|
||||
score_info=ScoreInfo(
|
||||
mods=state.mods,
|
||||
user=APIUser(id=user_id, name=name),
|
||||
ruleset=state.ruleset_id,
|
||||
maximum_statistics=state.maximum_statistics,
|
||||
)
|
||||
),
|
||||
store.state = state
|
||||
store.beatmap_status = beatmap.beatmap_status
|
||||
store.checksum = beatmap.checksum
|
||||
store.ruleset_id = state.ruleset_id
|
||||
store.score_token = score_token
|
||||
store.score = StoreScore(
|
||||
score_info=ScoreInfo(
|
||||
mods=state.mods,
|
||||
user=APIUser(id=user_id, name=name),
|
||||
ruleset=state.ruleset_id,
|
||||
maximum_statistics=state.maximum_statistics,
|
||||
)
|
||||
)
|
||||
self.state[user_id] = store
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
@@ -209,19 +220,16 @@ class SpectatorHub(Hub):
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
state = self.state.get(user_id)
|
||||
if not state:
|
||||
state = self.get_or_create_state(client)
|
||||
if not state.score:
|
||||
return
|
||||
score = state.score
|
||||
if not score:
|
||||
return
|
||||
score.score_info.acc = frame_data.header.acc
|
||||
score.score_info.combo = frame_data.header.combo
|
||||
score.score_info.max_combo = frame_data.header.max_combo
|
||||
score.score_info.statistics = frame_data.header.statistics
|
||||
score.score_info.total_score = frame_data.header.total_score
|
||||
score.score_info.mods = frame_data.header.mods
|
||||
score.replay_frames.extend(frame_data.frames)
|
||||
state.score.score_info.acc = frame_data.header.acc
|
||||
state.score.score_info.combo = frame_data.header.combo
|
||||
state.score.score_info.max_combo = frame_data.header.max_combo
|
||||
state.score.score_info.statistics = frame_data.header.statistics
|
||||
state.score.score_info.total_score = frame_data.header.total_score
|
||||
state.score.score_info.mods = frame_data.header.mods
|
||||
state.score.replay_frames.extend(frame_data.frames)
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserSentFrames",
|
||||
@@ -231,9 +239,7 @@ class SpectatorHub(Hub):
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.state.get(user_id)
|
||||
if not store:
|
||||
return
|
||||
store = self.get_or_create_state(client)
|
||||
score = store.score
|
||||
if not score or not store.score_token:
|
||||
return
|
||||
@@ -294,8 +300,15 @@ class SpectatorHub(Hub):
|
||||
):
|
||||
# save replay
|
||||
await _save_replay()
|
||||
store.state = None
|
||||
store.beatmap_status = None
|
||||
store.checksum = None
|
||||
store.ruleset_id = None
|
||||
store.score_token = None
|
||||
store.score = None
|
||||
await self._end_session(user_id, state)
|
||||
|
||||
del self.state[user_id]
|
||||
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
|
||||
if state.state == SpectatedUserState.Playing:
|
||||
state.state = SpectatedUserState.Quit
|
||||
await self.broadcast_group_call(
|
||||
@@ -308,22 +321,18 @@ class SpectatorHub(Hub):
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
|
||||
user_id = int(client.connection_id)
|
||||
target_store = self.state.get(target_id)
|
||||
if target_store and target_store.state:
|
||||
target_store = self.get_or_create_state(client)
|
||||
if target_store.state:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
serialize_to_list(target_store.state),
|
||||
)
|
||||
store = self.state.get(user_id)
|
||||
if store is None:
|
||||
store = StoreClientState(
|
||||
watched_user=set(),
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.watched_user.add(target_id)
|
||||
self.state[user_id] = store
|
||||
self.groups.setdefault(self.group_id(target_id), set()).add(client)
|
||||
|
||||
self.add_to_group(client, self.group_id(target_id))
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
@@ -340,7 +349,7 @@ class SpectatorHub(Hub):
|
||||
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
|
||||
user_id = int(client.connection_id)
|
||||
self.groups[self.group_id(target_id)].discard(client)
|
||||
self.remove_from_group(client, self.group_id(target_id))
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.watched_user.discard(target_id)
|
||||
|
||||
@@ -51,10 +51,18 @@ class PingPacket(Packet):
|
||||
type: PacketType = PacketType.PING
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ClosePacket(Packet):
|
||||
type: PacketType = PacketType.CLOSE
|
||||
error: str | None = None
|
||||
allow_reconnect: bool = False
|
||||
|
||||
|
||||
PACKETS = {
|
||||
PacketType.INVOCATION: InvocationPacket,
|
||||
PacketType.COMPLETION: CompletionPacket,
|
||||
PacketType.PING: PingPacket,
|
||||
PacketType.CLOSE: ClosePacket,
|
||||
}
|
||||
|
||||
|
||||
@@ -127,6 +135,13 @@ class MsgpackProtocol:
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
case PacketType.CLOSE:
|
||||
return [
|
||||
ClosePacket(
|
||||
error=unpacked[1],
|
||||
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
|
||||
)
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
@@ -156,6 +171,13 @@ class MsgpackProtocol:
|
||||
packet.error or packet.result or None,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, ClosePacket):
|
||||
payload.extend(
|
||||
[
|
||||
packet.error or "",
|
||||
packet.allow_reconnect,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, PingPacket):
|
||||
payload.pop(-1)
|
||||
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||
@@ -198,6 +220,13 @@ class JSONProtocol:
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
case PacketType.CLOSE:
|
||||
return [
|
||||
ClosePacket(
|
||||
error=data.get("error"),
|
||||
allow_reconnect=data.get("allowReconnect", False),
|
||||
)
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
@@ -231,6 +260,14 @@ class JSONProtocol:
|
||||
payload["result"] = packet.result
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
elif isinstance(packet, ClosePacket):
|
||||
payload.update(
|
||||
{
|
||||
"allowReconnect": packet.allow_reconnect,
|
||||
}
|
||||
)
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
return json.dumps(payload).encode("utf-8") + SEP
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ async def connect(
|
||||
|
||||
client = None
|
||||
try:
|
||||
client = hub_.add_client(
|
||||
client = await hub_.add_client(
|
||||
connection_id=user_id,
|
||||
connection_token=id,
|
||||
connection=websocket,
|
||||
@@ -87,13 +87,17 @@ async def connect(
|
||||
except ValueError as e:
|
||||
error = str(e)
|
||||
payload = {"error": error} if error else {}
|
||||
|
||||
# finish handshake
|
||||
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
||||
if error or not client:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await hub_.clean_state(client, False)
|
||||
task = asyncio.create_task(hub_.on_connect(client))
|
||||
hub_.tasks.add(task)
|
||||
task.add_done_callback(hub_.tasks.discard)
|
||||
await hub_._listen_client(client)
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
...
|
||||
|
||||
@@ -77,7 +77,7 @@ mark-parentheses = false
|
||||
keep-runtime-typing = true
|
||||
|
||||
[tool.pyright]
|
||||
pythonVersion = "3.11"
|
||||
pythonVersion = "3.12"
|
||||
pythonPlatform = "All"
|
||||
|
||||
typeCheckingMode = "standard"
|
||||
|
||||
Reference in New Issue
Block a user