from __future__ import annotations import asyncio import json import time from typing import Literal import uuid from app.database import User as DBUser from app.dependencies import get_current_user from app.dependencies.database import get_db from app.dependencies.user import get_current_user_by_token from app.models.signalr import NegotiateResponse, Transport from .hub import Hubs from .packet import PROTOCOLS, SEP from fastapi import APIRouter, Depends, Header, Query, WebSocket from sqlmodel.ext.asyncio.session import AsyncSession router = APIRouter() @router.post("/{hub}/negotiate", response_model=NegotiateResponse) async def negotiate( hub: Literal["spectator", "multiplayer", "metadata"], negotiate_version: int = Query(1, alias="negotiateVersion"), user: DBUser = Depends(get_current_user), ): connectionId = str(user.id) connectionToken = f"{connectionId}:{uuid.uuid4()}" Hubs[hub].add_waited_client( connection_token=connectionToken, timestamp=int(time.time()), ) return NegotiateResponse( connectionId=connectionId, connectionToken=connectionToken, negotiateVersion=negotiate_version, availableTransports=[Transport(transport="WebSockets")], ) @router.websocket("/{hub}") async def connect( hub: Literal["spectator", "multiplayer", "metadata"], websocket: WebSocket, id: str, authorization: str = Header(...), db: AsyncSession = Depends(get_db), ): token = authorization[7:] user_id = id.split(":")[0] hub_ = Hubs[hub] if id not in hub_: await websocket.close(code=1008) return if (user := await get_current_user_by_token(token, db)) is None or str( user.id ) != user_id: await websocket.close(code=1008) return await websocket.accept() # handshake handshake = await websocket.receive() message = handshake.get("bytes") or handshake.get("text") if not message: await websocket.close(code=1008) return handshake_payload = json.loads(message[:-1]) error = "" protocol = handshake_payload.get("protocol", "json") client = None try: client = await hub_.add_client( connection_id=user_id, connection_token=id, connection=websocket, protocol=PROTOCOLS[protocol], ) except KeyError: error = f"Protocol '{protocol}' is not supported." except TimeoutError: error = f"Connection {id} has waited too long." except ValueError as e: error = str(e) payload = {"error": error} if error else {} # finish handshake await websocket.send_bytes(json.dumps(payload).encode() + SEP) if error or not client: await websocket.close(code=1008) return await hub_.clean_state(client, False) task = asyncio.create_task(hub_.on_connect(client)) hub_.tasks.add(task) task.add_done_callback(hub_.tasks.discard) await hub_._listen_client(client) try: await websocket.close() except Exception: ...