feat(signalr): graceful state manager
This commit is contained in:
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.models.signalr import UserState
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -126,7 +128,7 @@ UserActivity = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataClientState(BaseModel):
|
class MetadataClientState(UserState):
|
||||||
user_activity: UserActivity | None = None
|
user_activity: UserActivity | None = None
|
||||||
status: OnlineStatus | None = None
|
status: OnlineStatus | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -64,3 +64,8 @@ class NegotiateResponse(BaseModel):
|
|||||||
connectionToken: str
|
connectionToken: str
|
||||||
negotiateVersion: int = 1
|
negotiateVersion: int = 1
|
||||||
availableTransports: list[Transport]
|
availableTransports: list[Transport]
|
||||||
|
|
||||||
|
|
||||||
|
class UserState(BaseModel):
|
||||||
|
connection_id: str
|
||||||
|
connection_token: str
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from app.models.beatmap import BeatmapRankStatus
|
|||||||
from .score import (
|
from .score import (
|
||||||
ScoreStatisticsInt,
|
ScoreStatisticsInt,
|
||||||
)
|
)
|
||||||
from .signalr import MessagePackArrayModel
|
from .signalr import MessagePackArrayModel, UserState
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
@@ -128,7 +128,7 @@ class StoreScore(BaseModel):
|
|||||||
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
|
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class StoreClientState(BaseModel):
|
class StoreClientState(UserState):
|
||||||
state: SpectatorState | None = None
|
state: SpectatorState | None = None
|
||||||
beatmap_status: BeatmapRankStatus | None = None
|
beatmap_status: BeatmapRankStatus | None = None
|
||||||
checksum: str | None = None
|
checksum: str | None = None
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.models.signalr import UserState
|
||||||
from app.signalr.exception import InvokeException
|
from app.signalr.exception import InvokeException
|
||||||
from app.signalr.packet import (
|
from app.signalr.packet import (
|
||||||
|
ClosePacket,
|
||||||
CompletionPacket,
|
CompletionPacket,
|
||||||
InvocationPacket,
|
InvocationPacket,
|
||||||
Packet,
|
Packet,
|
||||||
@@ -22,6 +25,19 @@ from pydantic import BaseModel
|
|||||||
from starlette.websockets import WebSocketDisconnect
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
|
||||||
|
|
||||||
|
class CloseConnection(Exception):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Connection closed",
|
||||||
|
allow_reconnect: bool = False,
|
||||||
|
from_client: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.allow_reconnect = allow_reconnect
|
||||||
|
self.from_client = from_client
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -39,7 +55,11 @@ class Client:
|
|||||||
self._store = ResultStore()
|
self._store = ResultStore()
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(self.connection_id + self.connection_token)
|
return hash(self.connection_token)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self) -> int:
|
||||||
|
return int(self.connection_id)
|
||||||
|
|
||||||
async def send_packet(self, packet: Packet):
|
async def send_packet(self, packet: Packet):
|
||||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||||
@@ -48,7 +68,7 @@ class Client:
|
|||||||
message = await self.connection.receive()
|
message = await self.connection.receive()
|
||||||
d = message.get("bytes") or message.get("text", "").encode()
|
d = message.get("bytes") or message.get("text", "").encode()
|
||||||
if not d:
|
if not d:
|
||||||
return [PingPacket()] # FIXME: Graceful empty message handling
|
return []
|
||||||
return self.procotol.decode(d)
|
return self.procotol.decode(d)
|
||||||
|
|
||||||
async def _ping(self):
|
async def _ping(self):
|
||||||
@@ -63,12 +83,13 @@ class Client:
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
class Hub:
|
class Hub[TState: UserState]:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.clients: dict[str, Client] = {}
|
self.clients: dict[str, Client] = {}
|
||||||
self.waited_clients: dict[str, int] = {}
|
self.waited_clients: dict[str, int] = {}
|
||||||
self.tasks: set[asyncio.Task] = set()
|
self.tasks: set[asyncio.Task] = set()
|
||||||
self.groups: dict[str, set[Client]] = {}
|
self.groups: dict[str, set[Client]] = {}
|
||||||
|
self.state: dict[int, TState] = {}
|
||||||
|
|
||||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||||
self.waited_clients[connection_token] = timestamp
|
self.waited_clients[connection_token] = timestamp
|
||||||
@@ -79,7 +100,25 @@ class Hub:
|
|||||||
return client
|
return client
|
||||||
return default
|
return default
|
||||||
|
|
||||||
def add_client(
|
@abstractmethod
|
||||||
|
def create_state(self, client: Client) -> TState:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_or_create_state(self, client: Client) -> TState:
|
||||||
|
if (state := self.state.get(client.user_id)) is not None:
|
||||||
|
return state
|
||||||
|
state = self.create_state(client)
|
||||||
|
self.state[client.user_id] = state
|
||||||
|
return state
|
||||||
|
|
||||||
|
def add_to_group(self, client: Client, group_id: str) -> None:
|
||||||
|
self.groups.setdefault(group_id, set()).add(client)
|
||||||
|
|
||||||
|
def remove_from_group(self, client: Client, group_id: str) -> None:
|
||||||
|
if group_id in self.groups:
|
||||||
|
self.groups[group_id].discard(client)
|
||||||
|
|
||||||
|
async def add_client(
|
||||||
self,
|
self,
|
||||||
connection_id: str,
|
connection_id: str,
|
||||||
connection_token: str,
|
connection_token: str,
|
||||||
@@ -104,19 +143,34 @@ class Hub:
|
|||||||
client._ping_task = task
|
client._ping_task = task
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
async def remove_client(self, client: Client) -> None:
|
||||||
|
del self.clients[client.connection_token]
|
||||||
|
if client._listen_task:
|
||||||
|
client._listen_task.cancel()
|
||||||
|
if client._ping_task:
|
||||||
|
client._ping_task.cancel()
|
||||||
|
for group in self.groups.values():
|
||||||
|
group.discard(client)
|
||||||
|
await self.clean_state(client, False)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _clean_state(self, state: TState) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def clean_state(self, client: Client, disconnected: bool) -> None:
|
||||||
|
if (state := self.state.get(client.user_id)) is None:
|
||||||
|
return
|
||||||
|
if disconnected and client.connection_token != state.connection_token:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await self._clean_state(state)
|
||||||
|
except Exception:
|
||||||
|
...
|
||||||
|
|
||||||
async def on_connect(self, client: Client) -> None:
|
async def on_connect(self, client: Client) -> None:
|
||||||
if method := getattr(self, "on_client_connect", None):
|
if method := getattr(self, "on_client_connect", None):
|
||||||
await method(client)
|
await method(client)
|
||||||
|
|
||||||
async def remove_client(self, connection_id: str) -> None:
|
|
||||||
if client := self.clients.get(connection_id):
|
|
||||||
del self.clients[connection_id]
|
|
||||||
if client._listen_task:
|
|
||||||
client._listen_task.cancel()
|
|
||||||
if client._ping_task:
|
|
||||||
client._ping_task.cancel()
|
|
||||||
await client.connection.close()
|
|
||||||
|
|
||||||
async def send_packet(self, client: Client, packet: Packet) -> None:
|
async def send_packet(self, client: Client, packet: Packet) -> None:
|
||||||
await client.send_packet(packet)
|
await client.send_packet(packet)
|
||||||
|
|
||||||
@@ -135,26 +189,40 @@ class Hub:
|
|||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def _listen_client(self, client: Client) -> None:
|
async def _listen_client(self, client: Client) -> None:
|
||||||
jump = False
|
try:
|
||||||
while not jump:
|
while True:
|
||||||
try:
|
|
||||||
packets = await client.receive_packets()
|
packets = await client.receive_packets()
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
if isinstance(packet, PingPacket):
|
if isinstance(packet, PingPacket):
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(packet, ClosePacket):
|
||||||
|
raise CloseConnection(
|
||||||
|
packet.error or "Connection closed by client",
|
||||||
|
packet.allow_reconnect,
|
||||||
|
True,
|
||||||
|
)
|
||||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||||
self.tasks.add(task)
|
self.tasks.add(task)
|
||||||
task.add_done_callback(self.tasks.discard)
|
task.add_done_callback(self.tasks.discard)
|
||||||
except WebSocketDisconnect as e:
|
except WebSocketDisconnect as e:
|
||||||
print(
|
print(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
|
||||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
except RuntimeError as e:
|
||||||
)
|
if "disconnect message" in str(e):
|
||||||
jump = True
|
print(f"Client {client.connection_id} closed the connection.")
|
||||||
except Exception as e:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(f"Error in client {client.connection_id}: {e}")
|
print(f"RuntimeError in client {client.connection_id}: {e}")
|
||||||
jump = True
|
except CloseConnection as e:
|
||||||
await self.remove_client(client.connection_id)
|
if not e.from_client:
|
||||||
|
await client.send_packet(
|
||||||
|
ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)
|
||||||
|
)
|
||||||
|
print(f"Client {client.connection_id} closed the connection: {e.message}")
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Error in client {client.connection_id}: {e}")
|
||||||
|
|
||||||
|
await self.remove_client(client)
|
||||||
|
|
||||||
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
||||||
if isinstance(packet, PingPacket):
|
if isinstance(packet, PingPacket):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
|
from typing import override
|
||||||
|
|
||||||
from app.database.relationship import Relationship, RelationshipType
|
from app.database.relationship import Relationship, RelationshipType
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import engine
|
||||||
@@ -16,32 +17,32 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||||
|
|
||||||
|
|
||||||
class MetadataHub(Hub):
|
class MetadataHub(Hub[MetadataClientState]):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.state: dict[int, MetadataClientState] = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def online_presence_watchers_group() -> str:
|
def online_presence_watchers_group() -> str:
|
||||||
return ONLINE_PRESENCE_WATCHERS_GROUP
|
return ONLINE_PRESENCE_WATCHERS_GROUP
|
||||||
|
|
||||||
def broadcast_tasks(
|
def broadcast_tasks(
|
||||||
self, user_id: int, store: MetadataClientState
|
self, user_id: int, store: MetadataClientState | None
|
||||||
) -> set[Coroutine]:
|
) -> set[Coroutine]:
|
||||||
if not store.pushable:
|
if store is not None and not store.pushable:
|
||||||
return set()
|
return set()
|
||||||
|
data = store.to_dict() if store else None
|
||||||
return {
|
return {
|
||||||
self.broadcast_group_call(
|
self.broadcast_group_call(
|
||||||
self.online_presence_watchers_group(),
|
self.online_presence_watchers_group(),
|
||||||
"UserPresenceUpdated",
|
"UserPresenceUpdated",
|
||||||
user_id,
|
user_id,
|
||||||
store.to_dict(),
|
data,
|
||||||
),
|
),
|
||||||
self.broadcast_group_call(
|
self.broadcast_group_call(
|
||||||
self.friend_presence_watchers_group(user_id),
|
self.friend_presence_watchers_group(user_id),
|
||||||
"FriendPresenceUpdated",
|
"FriendPresenceUpdated",
|
||||||
user_id,
|
user_id,
|
||||||
store.to_dict(),
|
data,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,11 +50,21 @@ class MetadataHub(Hub):
|
|||||||
def friend_presence_watchers_group(user_id: int):
|
def friend_presence_watchers_group(user_id: int):
|
||||||
return f"metadata:friend-presence-watchers:{user_id}"
|
return f"metadata:friend-presence-watchers:{user_id}"
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _clean_state(self, state: MetadataClientState) -> None:
|
||||||
|
if state.pushable:
|
||||||
|
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
|
||||||
|
|
||||||
|
@override
|
||||||
|
def create_state(self, client: Client) -> MetadataClientState:
|
||||||
|
return MetadataClientState(
|
||||||
|
connection_id=client.connection_id,
|
||||||
|
connection_token=client.connection_token,
|
||||||
|
)
|
||||||
|
|
||||||
async def on_client_connect(self, client: Client) -> None:
|
async def on_client_connect(self, client: Client) -> None:
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
if store := self.state.get(user_id):
|
self.get_or_create_state(client)
|
||||||
store = MetadataClientState()
|
|
||||||
self.state[user_id] = store
|
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
@@ -73,6 +84,7 @@ class MetadataHub(Hub):
|
|||||||
if (
|
if (
|
||||||
friend_state := self.state.get(friend_id)
|
friend_state := self.state.get(friend_id)
|
||||||
) and friend_state.pushable:
|
) and friend_state.pushable:
|
||||||
|
print("Pushed")
|
||||||
tasks.append(
|
tasks.append(
|
||||||
self.broadcast_group_call(
|
self.broadcast_group_call(
|
||||||
self.friend_presence_watchers_group(friend_id),
|
self.friend_presence_watchers_group(friend_id),
|
||||||
@@ -86,14 +98,10 @@ class MetadataHub(Hub):
|
|||||||
async def UpdateStatus(self, client: Client, status: int) -> None:
|
async def UpdateStatus(self, client: Client, status: int) -> None:
|
||||||
status_ = OnlineStatus(status)
|
status_ = OnlineStatus(status)
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
store = self.state.get(user_id)
|
store = self.get_or_create_state(client)
|
||||||
if store:
|
if store.status is not None and store.status == status_:
|
||||||
if store.status is not None and store.status == status_:
|
return
|
||||||
return
|
store.status = OnlineStatus(status_)
|
||||||
store.status = OnlineStatus(status_)
|
|
||||||
else:
|
|
||||||
store = MetadataClientState(status=OnlineStatus(status_))
|
|
||||||
self.state[user_id] = store
|
|
||||||
tasks = self.broadcast_tasks(user_id, store)
|
tasks = self.broadcast_tasks(user_id, store)
|
||||||
tasks.add(
|
tasks.add(
|
||||||
self.call_noblock(
|
self.call_noblock(
|
||||||
@@ -112,14 +120,8 @@ class MetadataHub(Hub):
|
|||||||
if activity_dict
|
if activity_dict
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
store = self.state.get(user_id)
|
store = self.get_or_create_state(client)
|
||||||
if store:
|
store.user_activity = activity
|
||||||
store.user_activity = activity
|
|
||||||
else:
|
|
||||||
store = MetadataClientState(
|
|
||||||
user_activity=activity,
|
|
||||||
)
|
|
||||||
self.state[user_id] = store
|
|
||||||
tasks = self.broadcast_tasks(user_id, store)
|
tasks = self.broadcast_tasks(user_id, store)
|
||||||
tasks.add(
|
tasks.add(
|
||||||
self.call_noblock(
|
self.call_noblock(
|
||||||
@@ -144,9 +146,7 @@ class MetadataHub(Hub):
|
|||||||
if store.pushable
|
if store.pushable
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.groups.setdefault(self.online_presence_watchers_group(), set()).add(client)
|
self.add_to_group(client, self.online_presence_watchers_group())
|
||||||
|
|
||||||
async def EndWatchingUserPresence(self, client: Client) -> None:
|
async def EndWatchingUserPresence(self, client: Client) -> None:
|
||||||
self.groups.setdefault(self.online_presence_watchers_group(), set()).discard(
|
self.remove_from_group(client, self.online_presence_watchers_group())
|
||||||
client
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import json
|
|||||||
import lzma
|
import lzma
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
|
from typing import override
|
||||||
|
|
||||||
from app.database import Beatmap
|
from app.database import Beatmap
|
||||||
from app.database.score import Score
|
from app.database.score import Score
|
||||||
@@ -140,15 +141,29 @@ def save_replay(
|
|||||||
replay_path.write_bytes(data)
|
replay_path.write_bytes(data)
|
||||||
|
|
||||||
|
|
||||||
class SpectatorHub(Hub):
|
class SpectatorHub(Hub[StoreClientState]):
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.state: dict[int, StoreClientState] = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def group_id(user_id: int) -> str:
|
def group_id(user_id: int) -> str:
|
||||||
return f"watch:{user_id}"
|
return f"watch:{user_id}"
|
||||||
|
|
||||||
|
@override
|
||||||
|
def create_state(self, client: Client) -> StoreClientState:
|
||||||
|
return StoreClientState(
|
||||||
|
connection_id=client.connection_id,
|
||||||
|
connection_token=client.connection_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def _clean_state(self, state: StoreClientState) -> None:
|
||||||
|
if state.state:
|
||||||
|
await self._end_session(int(state.connection_id), state.state)
|
||||||
|
for target in self.waited_clients:
|
||||||
|
target_client = self.get_client_by_id(target)
|
||||||
|
if target_client:
|
||||||
|
await self.call_noblock(
|
||||||
|
target_client, "UserEndedWatching", int(state.connection_id)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_client_connect(self, client: Client) -> None:
|
async def on_client_connect(self, client: Client) -> None:
|
||||||
tasks = [
|
tasks = [
|
||||||
self.call_noblock(
|
self.call_noblock(
|
||||||
@@ -163,8 +178,8 @@ class SpectatorHub(Hub):
|
|||||||
self, client: Client, score_token: int, state: SpectatorState
|
self, client: Client, score_token: int, state: SpectatorState
|
||||||
) -> None:
|
) -> None:
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
previous_state = self.state.get(user_id)
|
store = self.get_or_create_state(client)
|
||||||
if previous_state is not None:
|
if store.state is not None:
|
||||||
return
|
return
|
||||||
if state.beatmap_id is None or state.ruleset_id is None:
|
if state.beatmap_id is None or state.ruleset_id is None:
|
||||||
return
|
return
|
||||||
@@ -183,23 +198,19 @@ class SpectatorHub(Hub):
|
|||||||
if not user:
|
if not user:
|
||||||
return
|
return
|
||||||
name = user.name
|
name = user.name
|
||||||
store = StoreClientState(
|
store.state = state
|
||||||
state=state,
|
store.beatmap_status = beatmap.beatmap_status
|
||||||
beatmap_status=beatmap.beatmap_status,
|
store.checksum = beatmap.checksum
|
||||||
checksum=beatmap.checksum,
|
store.ruleset_id = state.ruleset_id
|
||||||
ruleset_id=state.ruleset_id,
|
store.score_token = score_token
|
||||||
score_token=score_token,
|
store.score = StoreScore(
|
||||||
watched_user=set(),
|
score_info=ScoreInfo(
|
||||||
score=StoreScore(
|
mods=state.mods,
|
||||||
score_info=ScoreInfo(
|
user=APIUser(id=user_id, name=name),
|
||||||
mods=state.mods,
|
ruleset=state.ruleset_id,
|
||||||
user=APIUser(id=user_id, name=name),
|
maximum_statistics=state.maximum_statistics,
|
||||||
ruleset=state.ruleset_id,
|
)
|
||||||
maximum_statistics=state.maximum_statistics,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.state[user_id] = store
|
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(user_id),
|
self.group_id(user_id),
|
||||||
"UserBeganPlaying",
|
"UserBeganPlaying",
|
||||||
@@ -209,19 +220,16 @@ class SpectatorHub(Hub):
|
|||||||
|
|
||||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
state = self.state.get(user_id)
|
state = self.get_or_create_state(client)
|
||||||
if not state:
|
if not state.score:
|
||||||
return
|
return
|
||||||
score = state.score
|
state.score.score_info.acc = frame_data.header.acc
|
||||||
if not score:
|
state.score.score_info.combo = frame_data.header.combo
|
||||||
return
|
state.score.score_info.max_combo = frame_data.header.max_combo
|
||||||
score.score_info.acc = frame_data.header.acc
|
state.score.score_info.statistics = frame_data.header.statistics
|
||||||
score.score_info.combo = frame_data.header.combo
|
state.score.score_info.total_score = frame_data.header.total_score
|
||||||
score.score_info.max_combo = frame_data.header.max_combo
|
state.score.score_info.mods = frame_data.header.mods
|
||||||
score.score_info.statistics = frame_data.header.statistics
|
state.score.replay_frames.extend(frame_data.frames)
|
||||||
score.score_info.total_score = frame_data.header.total_score
|
|
||||||
score.score_info.mods = frame_data.header.mods
|
|
||||||
score.replay_frames.extend(frame_data.frames)
|
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(user_id),
|
self.group_id(user_id),
|
||||||
"UserSentFrames",
|
"UserSentFrames",
|
||||||
@@ -231,9 +239,7 @@ class SpectatorHub(Hub):
|
|||||||
|
|
||||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
store = self.state.get(user_id)
|
store = self.get_or_create_state(client)
|
||||||
if not store:
|
|
||||||
return
|
|
||||||
score = store.score
|
score = store.score
|
||||||
if not score or not store.score_token:
|
if not score or not store.score_token:
|
||||||
return
|
return
|
||||||
@@ -294,8 +300,15 @@ class SpectatorHub(Hub):
|
|||||||
):
|
):
|
||||||
# save replay
|
# save replay
|
||||||
await _save_replay()
|
await _save_replay()
|
||||||
|
store.state = None
|
||||||
|
store.beatmap_status = None
|
||||||
|
store.checksum = None
|
||||||
|
store.ruleset_id = None
|
||||||
|
store.score_token = None
|
||||||
|
store.score = None
|
||||||
|
await self._end_session(user_id, state)
|
||||||
|
|
||||||
del self.state[user_id]
|
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
|
||||||
if state.state == SpectatedUserState.Playing:
|
if state.state == SpectatedUserState.Playing:
|
||||||
state.state = SpectatedUserState.Quit
|
state.state = SpectatedUserState.Quit
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
@@ -308,22 +321,18 @@ class SpectatorHub(Hub):
|
|||||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||||
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
|
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
target_store = self.state.get(target_id)
|
target_store = self.get_or_create_state(client)
|
||||||
if target_store and target_store.state:
|
if target_store.state:
|
||||||
await self.call_noblock(
|
await self.call_noblock(
|
||||||
client,
|
client,
|
||||||
"UserBeganPlaying",
|
"UserBeganPlaying",
|
||||||
target_id,
|
target_id,
|
||||||
serialize_to_list(target_store.state),
|
serialize_to_list(target_store.state),
|
||||||
)
|
)
|
||||||
store = self.state.get(user_id)
|
store = self.get_or_create_state(client)
|
||||||
if store is None:
|
|
||||||
store = StoreClientState(
|
|
||||||
watched_user=set(),
|
|
||||||
)
|
|
||||||
store.watched_user.add(target_id)
|
store.watched_user.add(target_id)
|
||||||
self.state[user_id] = store
|
|
||||||
self.groups.setdefault(self.group_id(target_id), set()).add(client)
|
self.add_to_group(client, self.group_id(target_id))
|
||||||
|
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
@@ -340,7 +349,7 @@ class SpectatorHub(Hub):
|
|||||||
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||||
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
|
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
self.groups[self.group_id(target_id)].discard(client)
|
self.remove_from_group(client, self.group_id(target_id))
|
||||||
store = self.state.get(user_id)
|
store = self.state.get(user_id)
|
||||||
if store:
|
if store:
|
||||||
store.watched_user.discard(target_id)
|
store.watched_user.discard(target_id)
|
||||||
|
|||||||
@@ -51,10 +51,18 @@ class PingPacket(Packet):
|
|||||||
type: PacketType = PacketType.PING
|
type: PacketType = PacketType.PING
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
|
class ClosePacket(Packet):
|
||||||
|
type: PacketType = PacketType.CLOSE
|
||||||
|
error: str | None = None
|
||||||
|
allow_reconnect: bool = False
|
||||||
|
|
||||||
|
|
||||||
PACKETS = {
|
PACKETS = {
|
||||||
PacketType.INVOCATION: InvocationPacket,
|
PacketType.INVOCATION: InvocationPacket,
|
||||||
PacketType.COMPLETION: CompletionPacket,
|
PacketType.COMPLETION: CompletionPacket,
|
||||||
PacketType.PING: PingPacket,
|
PacketType.PING: PingPacket,
|
||||||
|
PacketType.CLOSE: ClosePacket,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -127,6 +135,13 @@ class MsgpackProtocol:
|
|||||||
]
|
]
|
||||||
case PacketType.PING:
|
case PacketType.PING:
|
||||||
return [PingPacket()]
|
return [PingPacket()]
|
||||||
|
case PacketType.CLOSE:
|
||||||
|
return [
|
||||||
|
ClosePacket(
|
||||||
|
error=unpacked[1],
|
||||||
|
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
|
||||||
|
)
|
||||||
|
]
|
||||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -156,6 +171,13 @@ class MsgpackProtocol:
|
|||||||
packet.error or packet.result or None,
|
packet.error or packet.result or None,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
elif isinstance(packet, ClosePacket):
|
||||||
|
payload.extend(
|
||||||
|
[
|
||||||
|
packet.error or "",
|
||||||
|
packet.allow_reconnect,
|
||||||
|
]
|
||||||
|
)
|
||||||
elif isinstance(packet, PingPacket):
|
elif isinstance(packet, PingPacket):
|
||||||
payload.pop(-1)
|
payload.pop(-1)
|
||||||
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||||
@@ -198,6 +220,13 @@ class JSONProtocol:
|
|||||||
]
|
]
|
||||||
case PacketType.PING:
|
case PacketType.PING:
|
||||||
return [PingPacket()]
|
return [PingPacket()]
|
||||||
|
case PacketType.CLOSE:
|
||||||
|
return [
|
||||||
|
ClosePacket(
|
||||||
|
error=data.get("error"),
|
||||||
|
allow_reconnect=data.get("allowReconnect", False),
|
||||||
|
)
|
||||||
|
]
|
||||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -231,6 +260,14 @@ class JSONProtocol:
|
|||||||
payload["result"] = packet.result
|
payload["result"] = packet.result
|
||||||
elif isinstance(packet, PingPacket):
|
elif isinstance(packet, PingPacket):
|
||||||
pass
|
pass
|
||||||
|
elif isinstance(packet, ClosePacket):
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"allowReconnect": packet.allow_reconnect,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if packet.error is not None:
|
||||||
|
payload["error"] = packet.error
|
||||||
return json.dumps(payload).encode("utf-8") + SEP
|
return json.dumps(payload).encode("utf-8") + SEP
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ async def connect(
|
|||||||
|
|
||||||
client = None
|
client = None
|
||||||
try:
|
try:
|
||||||
client = hub_.add_client(
|
client = await hub_.add_client(
|
||||||
connection_id=user_id,
|
connection_id=user_id,
|
||||||
connection_token=id,
|
connection_token=id,
|
||||||
connection=websocket,
|
connection=websocket,
|
||||||
@@ -87,13 +87,17 @@ async def connect(
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
error = str(e)
|
error = str(e)
|
||||||
payload = {"error": error} if error else {}
|
payload = {"error": error} if error else {}
|
||||||
|
|
||||||
# finish handshake
|
# finish handshake
|
||||||
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
||||||
if error or not client:
|
if error or not client:
|
||||||
await websocket.close(code=1008)
|
await websocket.close(code=1008)
|
||||||
return
|
return
|
||||||
|
await hub_.clean_state(client, False)
|
||||||
task = asyncio.create_task(hub_.on_connect(client))
|
task = asyncio.create_task(hub_.on_connect(client))
|
||||||
hub_.tasks.add(task)
|
hub_.tasks.add(task)
|
||||||
task.add_done_callback(hub_.tasks.discard)
|
task.add_done_callback(hub_.tasks.discard)
|
||||||
await hub_._listen_client(client)
|
await hub_._listen_client(client)
|
||||||
|
try:
|
||||||
|
await websocket.close()
|
||||||
|
except Exception:
|
||||||
|
...
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ mark-parentheses = false
|
|||||||
keep-runtime-typing = true
|
keep-runtime-typing = true
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
pythonVersion = "3.11"
|
pythonVersion = "3.12"
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
|
|
||||||
typeCheckingMode = "standard"
|
typeCheckingMode = "standard"
|
||||||
|
|||||||
Reference in New Issue
Block a user