feat(signalr): graceful state manager

This commit is contained in:
MingxuanGame
2025-07-28 08:46:20 +00:00
parent 722a6e57d8
commit f60283a6c2
9 changed files with 234 additions and 109 deletions

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from enum import IntEnum
from typing import Any, Literal
from app.models.signalr import UserState
from pydantic import BaseModel, ConfigDict, Field
@@ -126,7 +128,7 @@ UserActivity = (
)
class MetadataClientState(BaseModel):
class MetadataClientState(UserState):
user_activity: UserActivity | None = None
status: OnlineStatus | None = None

View File

@@ -64,3 +64,8 @@ class NegotiateResponse(BaseModel):
connectionToken: str
negotiateVersion: int = 1
availableTransports: list[Transport]
class UserState(BaseModel):
    """Base model for per-user hub state, tagged with its owning connection."""

    # Connection id of the owner; hubs convert it with int() to obtain a user id.
    connection_id: str
    # Connection token of the owner; hubs compare it with a client's token
    # before cleaning state on disconnect.
    connection_token: str

View File

@@ -9,7 +9,7 @@ from app.models.beatmap import BeatmapRankStatus
from .score import (
ScoreStatisticsInt,
)
from .signalr import MessagePackArrayModel
from .signalr import MessagePackArrayModel, UserState
import msgpack
from pydantic import BaseModel, Field, field_validator
@@ -128,7 +128,7 @@ class StoreScore(BaseModel):
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
class StoreClientState(BaseModel):
class StoreClientState(UserState):
state: SpectatorState | None = None
beatmap_status: BeatmapRankStatus | None = None
checksum: str | None = None

View File

@@ -1,13 +1,16 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
import time
import traceback
from typing import Any
from app.config import settings
from app.models.signalr import UserState
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
InvocationPacket,
Packet,
@@ -22,6 +25,19 @@ from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class CloseConnection(Exception):
    """Signal that a client's connection should be closed.

    Attributes:
        message: Reason for closing; also used as the exception message.
        allow_reconnect: Reconnect flag forwarded in the close packet.
        from_client: True when the close was initiated by the client.
    """

    def __init__(
        self,
        message: str = "Connection closed",
        allow_reconnect: bool = False,
        from_client: bool = False,
    ) -> None:
        # Record the close details first, then initialise the base exception.
        self.message, self.allow_reconnect, self.from_client = (
            message,
            allow_reconnect,
            from_client,
        )
        super().__init__(message)
class Client:
def __init__(
self,
@@ -39,7 +55,11 @@ class Client:
self._store = ResultStore()
def __hash__(self) -> int:
return hash(self.connection_id + self.connection_token)
return hash(self.connection_token)
@property
def user_id(self) -> int:
    """Return the user id, parsed as an integer from ``connection_id``."""
    raw_id = self.connection_id
    return int(raw_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
@@ -48,7 +68,7 @@ class Client:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return [PingPacket()] # FIXME: Graceful empty message handling
return []
return self.procotol.decode(d)
async def _ping(self):
@@ -63,12 +83,13 @@ class Client:
break
class Hub:
class Hub[TState: UserState]:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
self.groups: dict[str, set[Client]] = {}
self.state: dict[int, TState] = {}
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
@@ -79,7 +100,25 @@ class Hub:
return client
return default
def add_client(
@abstractmethod
def create_state(self, client: Client) -> TState:
    """Build a fresh state object for ``client``.

    Concrete hubs override this; the base version always raises.

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    # NOTE(review): @abstractmethod is only enforced when the class uses
    # ABCMeta - confirm against the Hub class header.
    raise NotImplementedError()
def get_or_create_state(self, client: Client) -> TState:
    """Return the state stored for the client's user id, creating it if absent.

    A newly created state (from ``create_state``) is stored under
    ``client.user_id`` before being returned.
    """
    uid = client.user_id
    existing = self.state.get(uid)
    if existing is None:
        existing = self.create_state(client)
        self.state[uid] = existing
    return existing
def add_to_group(self, client: Client, group_id: str) -> None:
    """Add ``client`` to the group ``group_id``, creating the group if needed."""
    if group_id not in self.groups:
        self.groups[group_id] = set()
    members = self.groups[group_id]
    members.add(client)
def remove_from_group(self, client: Client, group_id: str) -> None:
    """Remove ``client`` from ``group_id``; unknown groups or members are ignored."""
    members = self.groups.get(group_id)
    if members is None:
        return
    members.discard(client)
async def add_client(
self,
connection_id: str,
connection_token: str,
@@ -104,19 +143,34 @@ class Hub:
client._ping_task = task
return client
async def remove_client(self, client: Client) -> None:
    """Unregister ``client`` and release everything tied to it.

    Drops the client from the registry (keyed by connection token), cancels
    its listen and ping tasks when present, removes it from every group, and
    finally runs state cleanup via ``clean_state(client, False)``.

    Calling this for a client that is not (or no longer) registered is safe:
    the remaining cleanup steps still run.
    """
    # pop() rather than del: a missing registration must not raise KeyError
    # and thereby skip the group and state cleanup below.
    self.clients.pop(client.connection_token, None)
    for task in (client._listen_task, client._ping_task):
        if task:
            task.cancel()
    for members in self.groups.values():
        members.discard(client)
    await self.clean_state(client, False)
@abstractmethod
async def _clean_state(self, state: TState) -> None:
    """Hub-specific cleanup hook invoked by ``clean_state``.

    The base implementation does nothing and returns ``None``.
    """
    return None
async def clean_state(self, client: Client, disconnected: bool) -> None:
    """Run hub-specific cleanup for the state stored under the client's user id.

    Nothing happens when no state is stored for ``client.user_id``, or when
    ``disconnected`` is true and the stored state carries a different
    connection token than ``client``.

    Cleanup is best-effort: an exception raised by ``_clean_state`` is
    reported with a traceback but never propagated to the caller.
    """
    state = self.state.get(client.user_id)
    if state is None:
        return
    if disconnected and client.connection_token != state.connection_token:
        # The stored state belongs to another connection token; leave it alone.
        return
    try:
        await self._clean_state(state)
    except Exception:
        # Still swallowed on purpose, but no longer silently: a failing
        # cleanup hook would otherwise be invisible.
        traceback.print_exc()
async def on_connect(self, client: Client) -> None:
    """Invoke the hub's optional ``on_client_connect`` hook for ``client``.

    Hubs that do not define ``on_client_connect`` are left untouched.
    """
    handler = getattr(self, "on_client_connect", None)
    if not handler:
        return
    await handler(client)
async def remove_client(self, connection_id: str) -> None:
if client := self.clients.get(connection_id):
del self.clients[connection_id]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
await client.connection.close()
async def send_packet(self, client: Client, packet: Packet) -> None:
await client.send_packet(packet)
@@ -135,26 +189,40 @@ class Hub:
await asyncio.gather(*tasks)
async def _listen_client(self, client: Client) -> None:
jump = False
while not jump:
try:
try:
while True:
packets = await client.receive_packets()
for packet in packets:
if isinstance(packet, PingPacket):
continue
elif isinstance(packet, ClosePacket):
raise CloseConnection(
packet.error or "Connection closed by client",
packet.allow_reconnect,
True,
)
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
)
jump = True
except Exception as e:
except WebSocketDisconnect as e:
print(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
print(f"Client {client.connection_id} closed the connection.")
else:
traceback.print_exc()
print(f"Error in client {client.connection_id}: {e}")
jump = True
await self.remove_client(client.connection_id)
print(f"RuntimeError in client {client.connection_id}: {e}")
except CloseConnection as e:
if not e.from_client:
await client.send_packet(
ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)
)
print(f"Client {client.connection_id} closed the connection: {e.message}")
except Exception as e:
traceback.print_exc()
print(f"Error in client {client.connection_id}: {e}")
await self.remove_client(client)
async def _handle_packet(self, client: Client, packet: Packet) -> None:
if isinstance(packet, PingPacket):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from typing import override
from app.database.relationship import Relationship, RelationshipType
from app.dependencies.database import engine
@@ -16,32 +17,32 @@ from sqlmodel.ext.asyncio.session import AsyncSession
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
class MetadataHub(Hub):
class MetadataHub(Hub[MetadataClientState]):
def __init__(self) -> None:
super().__init__()
self.state: dict[int, MetadataClientState] = {}
@staticmethod
def online_presence_watchers_group() -> str:
return ONLINE_PRESENCE_WATCHERS_GROUP
def broadcast_tasks(
self, user_id: int, store: MetadataClientState
self, user_id: int, store: MetadataClientState | None
) -> set[Coroutine]:
if not store.pushable:
if store is not None and not store.pushable:
return set()
data = store.to_dict() if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
"UserPresenceUpdated",
user_id,
store.to_dict(),
data,
),
self.broadcast_group_call(
self.friend_presence_watchers_group(user_id),
"FriendPresenceUpdated",
user_id,
store.to_dict(),
data,
),
}
@@ -49,11 +50,21 @@ class MetadataHub(Hub):
def friend_presence_watchers_group(user_id: int):
return f"metadata:friend-presence-watchers:{user_id}"
@override
async def _clean_state(self, state: MetadataClientState) -> None:
    """Broadcast a cleared (``None``) presence for the user if the state is pushable."""
    if not state.pushable:
        return
    user_id = int(state.connection_id)
    pending = self.broadcast_tasks(user_id, None)
    await asyncio.gather(*pending)
@override
def create_state(self, client: Client) -> MetadataClientState:
    """Create an empty metadata state bound to the client's connection id and token."""
    owner = {
        "connection_id": client.connection_id,
        "connection_token": client.connection_token,
    }
    return MetadataClientState(**owner)
async def on_client_connect(self, client: Client) -> None:
user_id = int(client.connection_id)
if store := self.state.get(user_id):
store = MetadataClientState()
self.state[user_id] = store
self.get_or_create_state(client)
async with AsyncSession(engine) as session:
async with session.begin():
@@ -73,6 +84,7 @@ class MetadataHub(Hub):
if (
friend_state := self.state.get(friend_id)
) and friend_state.pushable:
print("Pushed")
tasks.append(
self.broadcast_group_call(
self.friend_presence_watchers_group(friend_id),
@@ -86,14 +98,10 @@ class MetadataHub(Hub):
async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status)
user_id = int(client.connection_id)
store = self.state.get(user_id)
if store:
if store.status is not None and store.status == status_:
return
store.status = OnlineStatus(status_)
else:
store = MetadataClientState(status=OnlineStatus(status_))
self.state[user_id] = store
store = self.get_or_create_state(client)
if store.status is not None and store.status == status_:
return
store.status = OnlineStatus(status_)
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
@@ -112,14 +120,8 @@ class MetadataHub(Hub):
if activity_dict
else None
)
store = self.state.get(user_id)
if store:
store.user_activity = activity
else:
store = MetadataClientState(
user_activity=activity,
)
self.state[user_id] = store
store = self.get_or_create_state(client)
store.user_activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
@@ -144,9 +146,7 @@ class MetadataHub(Hub):
if store.pushable
]
)
self.groups.setdefault(self.online_presence_watchers_group(), set()).add(client)
self.add_to_group(client, self.online_presence_watchers_group())
async def EndWatchingUserPresence(self, client: Client) -> None:
self.groups.setdefault(self.online_presence_watchers_group(), set()).discard(
client
)
self.remove_from_group(client, self.online_presence_watchers_group())

View File

@@ -5,6 +5,7 @@ import json
import lzma
import struct
import time
from typing import override
from app.database import Beatmap
from app.database.score import Score
@@ -140,15 +141,29 @@ def save_replay(
replay_path.write_bytes(data)
class SpectatorHub(Hub):
def __init__(self) -> None:
super().__init__()
self.state: dict[int, StoreClientState] = {}
class SpectatorHub(Hub[StoreClientState]):
@staticmethod
def group_id(user_id: int) -> str:
return f"watch:{user_id}"
@override
def create_state(self, client: Client) -> StoreClientState:
    """Create an empty spectator state bound to the client's connection id and token."""
    conn_id = client.connection_id
    conn_token = client.connection_token
    return StoreClientState(connection_id=conn_id, connection_token=conn_token)
@override
async def _clean_state(self, state: StoreClientState) -> None:
    """Clean up a user's spectator state.

    Ends the user's play session when one is recorded, then sends
    ``UserEndedWatching`` (with this user's id) to every user this state
    was watching, provided that user currently has a client.
    """
    user_id = int(state.connection_id)
    if state.state:
        await self._end_session(user_id, state.state)
    # Iterate the users this state was watching. The previous loop walked
    # self.waited_clients, which maps connection tokens to timestamps and
    # therefore never yields a watched user's id.
    for target in state.watched_user:
        # assumes get_client_by_id matches on the string connection_id,
        # while watched_user holds integer user ids - TODO confirm
        target_client = self.get_client_by_id(str(target))
        if target_client:
            await self.call_noblock(target_client, "UserEndedWatching", user_id)
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
@@ -163,8 +178,8 @@ class SpectatorHub(Hub):
self, client: Client, score_token: int, state: SpectatorState
) -> None:
user_id = int(client.connection_id)
previous_state = self.state.get(user_id)
if previous_state is not None:
store = self.get_or_create_state(client)
if store.state is not None:
return
if state.beatmap_id is None or state.ruleset_id is None:
return
@@ -183,23 +198,19 @@ class SpectatorHub(Hub):
if not user:
return
name = user.name
store = StoreClientState(
state=state,
beatmap_status=beatmap.beatmap_status,
checksum=beatmap.checksum,
ruleset_id=state.ruleset_id,
score_token=score_token,
watched_user=set(),
score=StoreScore(
score_info=ScoreInfo(
mods=state.mods,
user=APIUser(id=user_id, name=name),
ruleset=state.ruleset_id,
maximum_statistics=state.maximum_statistics,
)
),
store.state = state
store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum
store.ruleset_id = state.ruleset_id
store.score_token = score_token
store.score = StoreScore(
score_info=ScoreInfo(
mods=state.mods,
user=APIUser(id=user_id, name=name),
ruleset=state.ruleset_id,
maximum_statistics=state.maximum_statistics,
)
)
self.state[user_id] = store
await self.broadcast_group_call(
self.group_id(user_id),
"UserBeganPlaying",
@@ -209,19 +220,16 @@ class SpectatorHub(Hub):
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
user_id = int(client.connection_id)
state = self.state.get(user_id)
if not state:
state = self.get_or_create_state(client)
if not state.score:
return
score = state.score
if not score:
return
score.score_info.acc = frame_data.header.acc
score.score_info.combo = frame_data.header.combo
score.score_info.max_combo = frame_data.header.max_combo
score.score_info.statistics = frame_data.header.statistics
score.score_info.total_score = frame_data.header.total_score
score.score_info.mods = frame_data.header.mods
score.replay_frames.extend(frame_data.frames)
state.score.score_info.acc = frame_data.header.acc
state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics
state.score.score_info.total_score = frame_data.header.total_score
state.score.score_info.mods = frame_data.header.mods
state.score.replay_frames.extend(frame_data.frames)
await self.broadcast_group_call(
self.group_id(user_id),
"UserSentFrames",
@@ -231,9 +239,7 @@ class SpectatorHub(Hub):
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
user_id = int(client.connection_id)
store = self.state.get(user_id)
if not store:
return
store = self.get_or_create_state(client)
score = store.score
if not score or not store.score_token:
return
@@ -294,8 +300,15 @@ class SpectatorHub(Hub):
):
# save replay
await _save_replay()
store.state = None
store.beatmap_status = None
store.checksum = None
store.ruleset_id = None
store.score_token = None
store.score = None
await self._end_session(user_id, state)
del self.state[user_id]
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit
await self.broadcast_group_call(
@@ -308,22 +321,18 @@ class SpectatorHub(Hub):
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
target_store = self.state.get(target_id)
if target_store and target_store.state:
target_store = self.get_or_create_state(client)
if target_store.state:
await self.call_noblock(
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
)
store = self.state.get(user_id)
if store is None:
store = StoreClientState(
watched_user=set(),
)
store = self.get_or_create_state(client)
store.watched_user.add(target_id)
self.state[user_id] = store
self.groups.setdefault(self.group_id(target_id), set()).add(client)
self.add_to_group(client, self.group_id(target_id))
async with AsyncSession(engine) as session:
async with session.begin():
@@ -340,7 +349,7 @@ class SpectatorHub(Hub):
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
self.groups[self.group_id(target_id)].discard(client)
self.remove_from_group(client, self.group_id(target_id))
store = self.state.get(user_id)
if store:
store.watched_user.discard(target_id)

View File

@@ -51,10 +51,18 @@ class PingPacket(Packet):
type: PacketType = PacketType.PING
@dataclass(kw_only=True)
class ClosePacket(Packet):
    """Close packet carrying an optional error text and a reconnect flag."""

    type: PacketType = PacketType.CLOSE
    # Reason for closing; None when no error text is attached.
    error: str | None = None
    # Reconnect flag carried by the packet; defaults to False.
    allow_reconnect: bool = False
PACKETS = {
PacketType.INVOCATION: InvocationPacket,
PacketType.COMPLETION: CompletionPacket,
PacketType.PING: PingPacket,
PacketType.CLOSE: ClosePacket,
}
@@ -127,6 +135,13 @@ class MsgpackProtocol:
]
case PacketType.PING:
return [PingPacket()]
case PacketType.CLOSE:
return [
ClosePacket(
error=unpacked[1],
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
)
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
@@ -156,6 +171,13 @@ class MsgpackProtocol:
packet.error or packet.result or None,
]
)
elif isinstance(packet, ClosePacket):
payload.extend(
[
packet.error or "",
packet.allow_reconnect,
]
)
elif isinstance(packet, PingPacket):
payload.pop(-1)
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
@@ -198,6 +220,13 @@ class JSONProtocol:
]
case PacketType.PING:
return [PingPacket()]
case PacketType.CLOSE:
return [
ClosePacket(
error=data.get("error"),
allow_reconnect=data.get("allowReconnect", False),
)
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
@@ -231,6 +260,14 @@ class JSONProtocol:
payload["result"] = packet.result
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):
payload.update(
{
"allowReconnect": packet.allow_reconnect,
}
)
if packet.error is not None:
payload["error"] = packet.error
return json.dumps(payload).encode("utf-8") + SEP

View File

@@ -74,7 +74,7 @@ async def connect(
client = None
try:
client = hub_.add_client(
client = await hub_.add_client(
connection_id=user_id,
connection_token=id,
connection=websocket,
@@ -87,13 +87,17 @@ async def connect(
except ValueError as e:
error = str(e)
payload = {"error": error} if error else {}
# finish handshake
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
if error or not client:
await websocket.close(code=1008)
return
await hub_.clean_state(client, False)
task = asyncio.create_task(hub_.on_connect(client))
hub_.tasks.add(task)
task.add_done_callback(hub_.tasks.discard)
await hub_._listen_client(client)
try:
await websocket.close()
except Exception:
...

View File

@@ -77,7 +77,7 @@ mark-parentheses = false
keep-runtime-typing = true
[tool.pyright]
pythonVersion = "3.11"
pythonVersion = "3.12"
pythonPlatform = "All"
typeCheckingMode = "standard"