chore(signalr): move to app/

This commit is contained in:
MingxuanGame
2025-07-27 02:42:14 +00:00
parent b359be3637
commit 0d684a1288
12 changed files with 8 additions and 7 deletions

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from app.signalr import signalr_router as signalr_router
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
beatmap,
beatmapset,
@@ -10,6 +12,5 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
from .api_router import router as api_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
from .signalr import signalr_router as signalr_router
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]

View File

@@ -1,5 +0,0 @@
from __future__ import annotations
from .router import router as signalr_router
__all__ = ["signalr_router"]

View File

@@ -1,10 +0,0 @@
from __future__ import annotations
class SignalRException(Exception):
pass
class InvokeException(SignalRException):
def __init__(self, message: str) -> None:
self.message = message

View File

@@ -1,15 +0,0 @@
from __future__ import annotations
from .hub import Hub
from .metadata import MetadataHub
from .multiplayer import MultiplayerHub
from .spectator import SpectatorHub
SpectatorHubs = SpectatorHub()
MultiplayerHubs = MultiplayerHub()
MetadataHubs = MetadataHub()
Hubs: dict[str, Hub] = {
"spectator": SpectatorHubs,
"multiplayer": MultiplayerHubs,
"metadata": MetadataHubs,
}

View File

@@ -1,211 +0,0 @@
from __future__ import annotations
import asyncio
import time
from typing import Any
from app.config import settings
from app.router.signalr.exception import InvokeException
from app.router.signalr.packet import (
PacketType,
ResultKind,
encode_varint,
parse_packet,
)
from app.router.signalr.store import ResultStore
from app.router.signalr.utils import get_signature
from fastapi import WebSocket
import msgpack
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class Client:
def __init__(
self, connection_id: str, connection_token: str, connection: WebSocket
) -> None:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
async def send_packet(self, type: PacketType, packet: list[Any]):
packet.insert(0, type.value)
payload = msgpack.packb(packet)
length = encode_varint(len(payload))
await self.connection.send_bytes(length + payload)
async def _ping(self):
while True:
try:
await self.send_packet(PacketType.PING, [])
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
except WebSocketDisconnect:
break
except Exception as e:
print(f"Error in ping task for {self.connection_id}: {e}")
break
class Hub:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
def add_client(
self, connection_id: str, connection_token: str, connection: WebSocket
) -> Client:
if connection_token in self.clients:
raise ValueError(
f"Client with connection token {connection_token} already exists."
)
if connection_token in self.waited_clients:
if (
self.waited_clients[connection_token]
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
):
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection)
self.clients[connection_token] = client
task = asyncio.create_task(client._ping())
self.tasks.add(task)
client._ping_task = task
return client
async def remove_client(self, connection_id: str) -> None:
if client := self.clients.get(connection_id):
del self.clients[connection_id]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
await client.connection.close()
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
await client.send_packet(type, packet)
async def _listen_client(self, client: Client) -> None:
jump = False
while not jump:
try:
message = await client.connection.receive_bytes()
packet_type, packet_data = parse_packet(message)
task = asyncio.create_task(
self._handle_packet(client, packet_type, packet_data)
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
if e.code == 1005:
continue
print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
)
jump = True
except Exception as e:
print(f"Error in client {client.connection_id}: {e}")
jump = True
await self.remove_client(client.connection_id)
async def _handle_packet(
self, client: Client, type: PacketType, packet: list[Any]
) -> None:
match type:
case PacketType.PING:
...
case PacketType.INVOCATION:
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
target: str = packet[2]
args: list[Any] | None = packet[3]
if args is None:
args = []
# streams: list[str] | None = packet[4] # TODO: stream support
code = ResultKind.VOID
result = None
try:
result = await self.invoke_method(client, target, args)
if result is not None:
code = ResultKind.HAS_VALUE
except InvokeException as e:
code = ResultKind.ERROR
result = e.message
except Exception as e:
code = ResultKind.ERROR
result = str(e)
packet = [
{}, # header
invocation_id,
code.value,
]
if result is not None:
packet.append(result)
if invocation_id is not None:
await client.send_packet(
PacketType.COMPLETION,
packet,
)
case PacketType.COMPLETION:
invocation_id: str = packet[1]
code: ResultKind = ResultKind(packet[2])
result: Any = packet[3] if len(packet) > 3 else None
client._store.add_result(invocation_id, code, result)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None)
call_params = []
if not method_:
raise InvokeException(f"Method '{method}' not found in hub.")
signature = get_signature(method_)
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
else:
call_params.append(args.pop(0))
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id()
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
invocation_id,
method,
list(args),
None, # streams
],
)
r = await client._store.fetch(invocation_id, None)
if r[0] == ResultKind.HAS_VALUE:
return r[1]
if r[0] == ResultKind.ERROR:
raise InvokeException(r[1])
return None
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
None, # invocation_id
method,
list(args),
None, # streams
],
)
return None
def __contains__(self, item: str) -> bool:
return item in self.clients or item in self.waited_clients

View File

@@ -1,6 +0,0 @@
from __future__ import annotations
from .hub import Hub
class MetadataHub(Hub): ...

View File

@@ -1,6 +0,0 @@
from __future__ import annotations
from .hub import Hub
class MultiplayerHub(Hub): ...

View File

@@ -1,15 +0,0 @@
from __future__ import annotations
from app.models.spectator_hub import FrameDataBundle, SpectatorState
from .hub import Client, Hub
class SpectatorHub(Hub):
async def BeginPlaySession(
self, client: Client, score_token: int, state: SpectatorState
) -> None: ...
async def SendFrameData(
self, client: Client, frame_data: FrameDataBundle
) -> None: ...

View File

@@ -1,56 +0,0 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any
import msgpack
SEP = b"\x1e"
class PacketType(IntEnum):
INVOCATION = 1
STREAM_ITEM = 2
COMPLETION = 3
STREAM_INVOCATION = 4
CANCEL_INVOCATION = 5
PING = 6
CLOSE = 7
class ResultKind(IntEnum):
ERROR = 1
VOID = 2
HAS_VALUE = 3
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
length, offset = decode_varint(data)
message_data = data[offset : offset + length]
unpacked = msgpack.unpackb(message_data, raw=False)
return PacketType(unpacked[0]), unpacked[1:]
def encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos

View File

@@ -1,91 +0,0 @@
from __future__ import annotations
import json
import time
from typing import Literal
import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP
from .hub import Hubs
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter()
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
async def negotiate(
hub: Literal["spectator", "multiplayer", "metadata"],
negotiate_version: int = Query(1, alias="negotiateVersion"),
user: DBUser = Depends(get_current_user),
):
connectionId = str(user.id)
connectionToken = f"{connectionId}:{uuid.uuid4()}"
Hubs[hub].add_waited_client(
connection_token=connectionToken,
timestamp=int(time.time()),
)
return NegotiateResponse(
connectionId=connectionId,
connectionToken=connectionToken,
negotiateVersion=negotiate_version,
availableTransports=[Transport(transport="WebSockets")],
)
@router.websocket("/{hub}")
async def connect(
hub: Literal["spectator", "multiplayer", "metadata"],
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: AsyncSession = Depends(get_db),
):
token = authorization[7:]
user_id = id.split(":")[0]
hub_ = Hubs[hub]
if id not in hub_:
await websocket.close(code=1008)
return
if (user := await get_current_user_by_token(token, db)) is None or str(
user.id
) != user_id:
await websocket.close(code=1008)
return
await websocket.accept()
# handshake
handshake = await websocket.receive_bytes()
handshake_payload = json.loads(handshake[:-1])
error = ""
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
handshake_payload.get("version")
) != 1:
error = f"Requested protocol '{protocol}' is not available."
client = None
try:
client = hub_.add_client(
connection_id=user_id,
connection_token=id,
connection=websocket,
)
except TimeoutError:
error = f"Connection {id} has waited too long."
except ValueError as e:
error = str(e)
payload = {"error": error} if error else {}
# finish handshake
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
if error or not client:
await websocket.close(code=1008)
return
await hub_._listen_client(client)

View File

@@ -1,45 +0,0 @@
from __future__ import annotations
import asyncio
import sys
from typing import Any, Literal
from app.router.signalr.packet import ResultKind
class ResultStore:
def __init__(self) -> None:
self._seq: int = 1
self._futures: dict[str, asyncio.Future] = {}
@property
def current_invocation_id(self) -> int:
return self._seq
def get_invocation_id(self) -> str:
s = self._seq
self._seq = (self._seq + 1) % sys.maxsize
return str(s)
def add_result(
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
) -> None:
if isinstance(invocation_id, str) and invocation_id.isdecimal():
if future := self._futures.get(invocation_id):
future.set_result((type, result))
async def fetch(
self,
invocation_id: str,
timeout: float | None, # noqa: ASYNC109
) -> (
tuple[Literal[ResultKind.ERROR], str]
| tuple[Literal[ResultKind.VOID], None]
| tuple[Literal[ResultKind.HAS_VALUE], Any]
):
future = asyncio.get_event_loop().create_future()
self._futures[invocation_id] = future
try:
return await asyncio.wait_for(future, timeout)
finally:
del self._futures[invocation_id]

View File

@@ -1,48 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
import inspect
from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
def evaluate_forwardref(
type_: ForwardRef,
globalns: Any,
localns: Any,
) -> Any:
# Even though it is the right signature for python 3.9,
# mypy complains with
# `error: Too many arguments for "_evaluate" of
# "ForwardRef"` hence the cast...
return cast(Any, type_)._evaluate(
globalns,
localns,
set(),
)
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
try:
annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception:
return inspect.Parameter.empty
return annotation
def get_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_annotation(param, globalns),
)
for param in signature.parameters.values()
]
return inspect.Signature(typed_params)