diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 615ea9b..8ae3e65 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -3,6 +3,8 @@ from __future__ import annotations from enum import IntEnum from typing import Any, Literal +from app.models.signalr import UserState + from pydantic import BaseModel, ConfigDict, Field @@ -126,7 +128,7 @@ UserActivity = ( ) -class MetadataClientState(BaseModel): +class MetadataClientState(UserState): user_activity: UserActivity | None = None status: OnlineStatus | None = None diff --git a/app/models/signalr.py b/app/models/signalr.py index ac8475f..09c85be 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -64,3 +64,8 @@ class NegotiateResponse(BaseModel): connectionToken: str negotiateVersion: int = 1 availableTransports: list[Transport] + + +class UserState(BaseModel): + connection_id: str + connection_token: str diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index 820eb16..0575a57 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -9,7 +9,7 @@ from app.models.beatmap import BeatmapRankStatus from .score import ( ScoreStatisticsInt, ) -from .signalr import MessagePackArrayModel +from .signalr import MessagePackArrayModel, UserState import msgpack from pydantic import BaseModel, Field, field_validator @@ -128,7 +128,7 @@ class StoreScore(BaseModel): replay_frames: list[LegacyReplayFrame] = Field(default_factory=list) -class StoreClientState(BaseModel): +class StoreClientState(UserState): state: SpectatorState | None = None beatmap_status: BeatmapRankStatus | None = None checksum: str | None = None diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index e5c807c..a4882b2 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -1,13 +1,16 @@ from __future__ import annotations +from abc import abstractmethod import asyncio import time import traceback from typing import Any from app.config import settings +from app.models.signalr import UserState from app.signalr.exception import InvokeException from app.signalr.packet import ( + ClosePacket, CompletionPacket, InvocationPacket, Packet, @@ -22,6 +25,19 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect +class CloseConnection(Exception): + def __init__( + self, + message: str = "Connection closed", + allow_reconnect: bool = False, + from_client: bool = False, + ) -> None: + super().__init__(message) + self.message = message + self.allow_reconnect = allow_reconnect + self.from_client = from_client + + class Client: def __init__( self, @@ -39,7 +55,11 @@ class Client: self._store = ResultStore() def __hash__(self) -> int: - return hash(self.connection_id + self.connection_token) + return hash(self.connection_token) + + @property + def user_id(self) -> int: + return int(self.connection_id) async def send_packet(self, packet: Packet): await self.connection.send_bytes(self.procotol.encode(packet)) @@ -48,7 +68,7 @@ class Client: message = await self.connection.receive() d = message.get("bytes") or message.get("text", "").encode() if not d: - return [PingPacket()] # FIXME: Graceful empty message handling + return [] return self.procotol.decode(d) async def _ping(self): @@ -63,12 +83,13 @@ class Client: break -class Hub: +class Hub[TState: UserState]: def __init__(self) -> None: self.clients: dict[str, Client] = {} self.waited_clients: dict[str, int] = {} self.tasks: set[asyncio.Task] = set() self.groups: dict[str, set[Client]] = {} + self.state: dict[int, TState] = {} def add_waited_client(self, connection_token: str, timestamp: int) -> None: self.waited_clients[connection_token] = timestamp @@ -79,7 +100,25 @@ class Hub: return client return default - def add_client( + @abstractmethod + def create_state(self, client: Client) -> TState: + raise NotImplementedError + + def get_or_create_state(self, client: Client) -> TState: + if (state := self.state.get(client.user_id)) is not None: + return state + state = self.create_state(client) + self.state[client.user_id] = state + return state + + def add_to_group(self, client: Client, group_id: str) -> None: + self.groups.setdefault(group_id, set()).add(client) + + def remove_from_group(self, client: Client, group_id: str) -> None: + if group_id in self.groups: + self.groups[group_id].discard(client) + + async def add_client( self, connection_id: str, connection_token: str, @@ -104,19 +143,34 @@ class Hub: client._ping_task = task return client + async def remove_client(self, client: Client) -> None: + del self.clients[client.connection_token] + if client._listen_task: + client._listen_task.cancel() + if client._ping_task: + client._ping_task.cancel() + for group in self.groups.values(): + group.discard(client) + await self.clean_state(client, False) + + @abstractmethod + async def _clean_state(self, state: TState) -> None: + return + + async def clean_state(self, client: Client, disconnected: bool) -> None: + if (state := self.state.get(client.user_id)) is None: + return + if disconnected and client.connection_token != state.connection_token: + return + try: + await self._clean_state(state) + except Exception: + ... + async def on_connect(self, client: Client) -> None: if method := getattr(self, "on_client_connect", None): await method(client) - async def remove_client(self, connection_id: str) -> None: - if client := self.clients.get(connection_id): - del self.clients[connection_id] - if client._listen_task: - client._listen_task.cancel() - if client._ping_task: - client._ping_task.cancel() - await client.connection.close() - async def send_packet(self, client: Client, packet: Packet) -> None: await client.send_packet(packet) @@ -135,26 +189,40 @@ class Hub: await asyncio.gather(*tasks) async def _listen_client(self, client: Client) -> None: - jump = False - while not jump: - try: + try: + while True: packets = await client.receive_packets() for packet in packets: if isinstance(packet, PingPacket): continue + elif isinstance(packet, ClosePacket): + raise CloseConnection( + packet.error or "Connection closed by client", + packet.allow_reconnect, + True, + ) task = asyncio.create_task(self._handle_packet(client, packet)) self.tasks.add(task) task.add_done_callback(self.tasks.discard) - except WebSocketDisconnect as e: - print( - f"Client {client.connection_id} disconnected: {e.code}, {e.reason}" - ) - jump = True - except Exception as e: + except WebSocketDisconnect as e: + print(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}") + except RuntimeError as e: + if "disconnect message" in str(e): + print(f"Client {client.connection_id} closed the connection.") + else: traceback.print_exc() - print(f"Error in client {client.connection_id}: {e}") - jump = True - await self.remove_client(client.connection_id) + print(f"RuntimeError in client {client.connection_id}: {e}") + except CloseConnection as e: + if not e.from_client: + await client.send_packet( + ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect) + ) + print(f"Client {client.connection_id} closed the connection: {e.message}") + except Exception as e: + traceback.print_exc() + print(f"Error in client {client.connection_id}: {e}") + + await self.remove_client(client) async def _handle_packet(self, client: Client, packet: Packet) -> None: if isinstance(packet, PingPacket): diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 4229723..03774e7 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine +from typing import override from app.database.relationship import Relationship, RelationshipType from app.dependencies.database import engine @@ -16,32 +17,32 @@ from sqlmodel.ext.asyncio.session import AsyncSession ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers" -class MetadataHub(Hub): +class MetadataHub(Hub[MetadataClientState]): def __init__(self) -> None: super().__init__() - self.state: dict[int, MetadataClientState] = {} @staticmethod def online_presence_watchers_group() -> str: return ONLINE_PRESENCE_WATCHERS_GROUP def broadcast_tasks( - self, user_id: int, store: MetadataClientState + self, user_id: int, store: MetadataClientState | None ) -> set[Coroutine]: - if not store.pushable: + if store is not None and not store.pushable: return set() + data = store.to_dict() if store else None return { self.broadcast_group_call( self.online_presence_watchers_group(), "UserPresenceUpdated", user_id, - store.to_dict(), + data, ), self.broadcast_group_call( self.friend_presence_watchers_group(user_id), "FriendPresenceUpdated", user_id, - store.to_dict(), + data, ), } @@ -49,11 +50,21 @@ class MetadataHub(Hub): def friend_presence_watchers_group(user_id: int): return f"metadata:friend-presence-watchers:{user_id}" + @override + async def _clean_state(self, state: MetadataClientState) -> None: + if state.pushable: + await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None)) + + @override + def create_state(self, client: Client) -> MetadataClientState: + return MetadataClientState( + connection_id=client.connection_id, + connection_token=client.connection_token, + ) + async def on_client_connect(self, client: Client) -> None: user_id = int(client.connection_id) - if store := self.state.get(user_id): - store = MetadataClientState() - self.state[user_id] = store + self.get_or_create_state(client) async with AsyncSession(engine) as session: async with session.begin(): @@ -73,6 +84,7 @@ class MetadataHub(Hub): if ( friend_state := self.state.get(friend_id) ) and friend_state.pushable: + print("Pushed") tasks.append( self.broadcast_group_call( self.friend_presence_watchers_group(friend_id), @@ -86,14 +98,10 @@ class MetadataHub(Hub): async def UpdateStatus(self, client: Client, status: int) -> None: status_ = OnlineStatus(status) user_id = int(client.connection_id) - store = self.state.get(user_id) - if store: - if store.status is not None and store.status == status_: - return - store.status = OnlineStatus(status_) - else: - store = MetadataClientState(status=OnlineStatus(status_)) - self.state[user_id] = store + store = self.get_or_create_state(client) + if store.status is not None and store.status == status_: + return + store.status = OnlineStatus(status_) tasks = self.broadcast_tasks(user_id, store) tasks.add( self.call_noblock( @@ -112,14 +120,8 @@ class MetadataHub(Hub): if activity_dict else None ) - store = self.state.get(user_id) - if store: - store.user_activity = activity - else: - store = MetadataClientState( - user_activity=activity, - ) - self.state[user_id] = store + store = self.get_or_create_state(client) + store.user_activity = activity tasks = self.broadcast_tasks(user_id, store) tasks.add( self.call_noblock( @@ -144,9 +146,7 @@ class MetadataHub(Hub): if store.pushable ] ) - self.groups.setdefault(self.online_presence_watchers_group(), set()).add(client) + self.add_to_group(client, self.online_presence_watchers_group()) async def EndWatchingUserPresence(self, client: Client) -> None: - self.groups.setdefault(self.online_presence_watchers_group(), set()).discard( - client - ) + self.remove_from_group(client, self.online_presence_watchers_group()) diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index c0be2dd..5f7e6d1 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -5,6 +5,7 @@ import json import lzma import struct import time +from typing import override from app.database import Beatmap from app.database.score import Score @@ -140,15 +141,29 @@ def save_replay( replay_path.write_bytes(data) -class SpectatorHub(Hub): - def __init__(self) -> None: - super().__init__() - self.state: dict[int, StoreClientState] = {} - +class SpectatorHub(Hub[StoreClientState]): @staticmethod def group_id(user_id: int) -> str: return f"watch:{user_id}" + @override + def create_state(self, client: Client) -> StoreClientState: + return StoreClientState( + connection_id=client.connection_id, + connection_token=client.connection_token, + ) + + @override + async def _clean_state(self, state: StoreClientState) -> None: + if state.state: + await self._end_session(int(state.connection_id), state.state) + for target in self.waited_clients: + target_client = self.get_client_by_id(target) + if target_client: + await self.call_noblock( + target_client, "UserEndedWatching", int(state.connection_id) + ) + async def on_client_connect(self, client: Client) -> None: tasks = [ self.call_noblock( @@ -163,8 +178,8 @@ class SpectatorHub(Hub): self, client: Client, score_token: int, state: SpectatorState ) -> None: user_id = int(client.connection_id) - previous_state = self.state.get(user_id) - if previous_state is not None: + store = self.get_or_create_state(client) + if store.state is not None: return if state.beatmap_id is None or state.ruleset_id is None: return @@ -183,23 +198,19 @@ class SpectatorHub(Hub): if not user: return name = user.name - store = StoreClientState( - state=state, - beatmap_status=beatmap.beatmap_status, - checksum=beatmap.checksum, - ruleset_id=state.ruleset_id, - score_token=score_token, - watched_user=set(), - score=StoreScore( - score_info=ScoreInfo( - mods=state.mods, - user=APIUser(id=user_id, name=name), - ruleset=state.ruleset_id, - maximum_statistics=state.maximum_statistics, - ) - ), + store.state = state + store.beatmap_status = beatmap.beatmap_status + store.checksum = beatmap.checksum + store.ruleset_id = state.ruleset_id + store.score_token = score_token + store.score = StoreScore( + score_info=ScoreInfo( + mods=state.mods, + user=APIUser(id=user_id, name=name), + ruleset=state.ruleset_id, + maximum_statistics=state.maximum_statistics, + ) ) - self.state[user_id] = store await self.broadcast_group_call( self.group_id(user_id), "UserBeganPlaying", @@ -209,19 +220,16 @@ class SpectatorHub(Hub): async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: user_id = int(client.connection_id) - state = self.state.get(user_id) - if not state: + state = self.get_or_create_state(client) + if not state.score: return - score = state.score - if not score: - return - score.score_info.acc = frame_data.header.acc - score.score_info.combo = frame_data.header.combo - score.score_info.max_combo = frame_data.header.max_combo - score.score_info.statistics = frame_data.header.statistics - score.score_info.total_score = frame_data.header.total_score - score.score_info.mods = frame_data.header.mods - score.replay_frames.extend(frame_data.frames) + state.score.score_info.acc = frame_data.header.acc + state.score.score_info.combo = frame_data.header.combo + state.score.score_info.max_combo = frame_data.header.max_combo + state.score.score_info.statistics = frame_data.header.statistics + state.score.score_info.total_score = frame_data.header.total_score + state.score.score_info.mods = frame_data.header.mods + state.score.replay_frames.extend(frame_data.frames) await self.broadcast_group_call( self.group_id(user_id), "UserSentFrames", @@ -231,9 +239,7 @@ class SpectatorHub(Hub): async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: user_id = int(client.connection_id) - store = self.state.get(user_id) - if not store: - return + store = self.get_or_create_state(client) score = store.score if not score or not store.score_token: return @@ -294,8 +300,15 @@ class SpectatorHub(Hub): ): # save replay await _save_replay() + store.state = None + store.beatmap_status = None + store.checksum = None + store.ruleset_id = None + store.score_token = None + store.score = None + await self._end_session(user_id, state) - del self.state[user_id] + async def _end_session(self, user_id: int, state: SpectatorState) -> None: if state.state == SpectatedUserState.Playing: state.state = SpectatedUserState.Quit await self.broadcast_group_call( @@ -308,22 +321,18 @@ class SpectatorHub(Hub): async def StartWatchingUser(self, client: Client, target_id: int) -> None: print(f"StartWatchingUser -> {client.connection_id} {target_id}") user_id = int(client.connection_id) - target_store = self.state.get(target_id) - if target_store and target_store.state: + target_store = self.get_or_create_state(client) + if target_store.state: await self.call_noblock( client, "UserBeganPlaying", target_id, serialize_to_list(target_store.state), ) - store = self.state.get(user_id) - if store is None: - store = StoreClientState( - watched_user=set(), - ) + store = self.get_or_create_state(client) store.watched_user.add(target_id) - self.state[user_id] = store - self.groups.setdefault(self.group_id(target_id), set()).add(client) + + self.add_to_group(client, self.group_id(target_id)) async with AsyncSession(engine) as session: async with session.begin(): @@ -340,7 +349,7 @@ class SpectatorHub(Hub): async def EndWatchingUser(self, client: Client, target_id: int) -> None: print(f"EndWatchingUser -> {client.connection_id} {target_id}") user_id = int(client.connection_id) - self.groups[self.group_id(target_id)].discard(client) + self.remove_from_group(client, self.group_id(target_id)) store = self.state.get(user_id) if store: store.watched_user.discard(target_id) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 1ff9b83..3dfc8ca 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -51,10 +51,18 @@ class PingPacket(Packet): type: PacketType = PacketType.PING +@dataclass(kw_only=True) +class ClosePacket(Packet): + type: PacketType = PacketType.CLOSE + error: str | None = None + allow_reconnect: bool = False + + PACKETS = { PacketType.INVOCATION: InvocationPacket, PacketType.COMPLETION: CompletionPacket, PacketType.PING: PingPacket, + PacketType.CLOSE: ClosePacket, } @@ -127,6 +135,13 @@ class MsgpackProtocol: ] case PacketType.PING: return [PingPacket()] + case PacketType.CLOSE: + return [ + ClosePacket( + error=unpacked[1], + allow_reconnect=unpacked[2] if len(unpacked) > 2 else False, + ) + ] raise ValueError(f"Unsupported packet type: {packet_type}") @staticmethod @@ -156,6 +171,13 @@ class MsgpackProtocol: packet.error or packet.result or None, ] ) + elif isinstance(packet, ClosePacket): + payload.extend( + [ + packet.error or "", + packet.allow_reconnect, + ] + ) elif isinstance(packet, PingPacket): payload.pop(-1) data = msgpack.packb(payload, use_bin_type=True, datetime=True) @@ -198,6 +220,13 @@ class JSONProtocol: ] case PacketType.PING: return [PingPacket()] + case PacketType.CLOSE: + return [ + ClosePacket( + error=data.get("error"), + allow_reconnect=data.get("allowReconnect", False), + ) + ] raise ValueError(f"Unsupported packet type: {packet_type}") @staticmethod @@ -231,6 +260,14 @@ class JSONProtocol: payload["result"] = packet.result elif isinstance(packet, PingPacket): pass + elif isinstance(packet, ClosePacket): + payload.update( + { + "allowReconnect": packet.allow_reconnect, + } + ) + if packet.error is not None: + payload["error"] = packet.error return json.dumps(payload).encode("utf-8") + SEP diff --git a/app/signalr/router.py b/app/signalr/router.py index 5c2f08f..237a575 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -74,7 +74,7 @@ async def connect( client = None try: - client = hub_.add_client( + client = await hub_.add_client( connection_id=user_id, connection_token=id, connection=websocket, @@ -87,13 +87,17 @@ async def connect( except ValueError as e: error = str(e) payload = {"error": error} if error else {} - # finish handshake await websocket.send_bytes(json.dumps(payload).encode() + SEP) if error or not client: await websocket.close(code=1008) return + await hub_.clean_state(client, False) task = asyncio.create_task(hub_.on_connect(client)) hub_.tasks.add(task) task.add_done_callback(hub_.tasks.discard) await hub_._listen_client(client) + try: + await websocket.close() + except Exception: + ... diff --git a/pyproject.toml b/pyproject.toml index 0c50d97..415208c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ mark-parentheses = false keep-runtime-typing = true [tool.pyright] -pythonVersion = "3.11" +pythonVersion = "3.12" pythonPlatform = "All" typeCheckingMode = "standard"