115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import Literal
|
|
import uuid
|
|
|
|
from app.database import User as DBUser
|
|
from app.dependencies import get_current_user
|
|
from app.dependencies.database import get_db
|
|
from app.models.signalr import NegotiateResponse, Transport
|
|
|
|
from .hub import Hubs
|
|
from .packet import PROTOCOLS, SEP
|
|
|
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
|
from fastapi.security import SecurityScopes
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
router = APIRouter(prefix="/signalr", tags=["SignalR"])
|
|
|
|
|
|
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
|
|
async def negotiate(
|
|
hub: Literal["spectator", "multiplayer", "metadata"],
|
|
negotiate_version: int = Query(1, alias="negotiateVersion"),
|
|
user: DBUser = Depends(get_current_user),
|
|
):
|
|
connectionId = str(user.id)
|
|
connectionToken = f"{connectionId}:{uuid.uuid4()}"
|
|
Hubs[hub].add_waited_client(
|
|
connection_token=connectionToken,
|
|
timestamp=int(time.time()),
|
|
)
|
|
return NegotiateResponse(
|
|
connectionId=connectionId,
|
|
connectionToken=connectionToken,
|
|
negotiateVersion=negotiate_version,
|
|
availableTransports=[Transport(transport="WebSockets")],
|
|
)
|
|
|
|
|
|
@router.websocket("/{hub}")
|
|
async def connect(
|
|
hub: Literal["spectator", "multiplayer", "metadata"],
|
|
websocket: WebSocket,
|
|
id: str,
|
|
authorization: str = Header(...),
|
|
db: AsyncSession = Depends(get_db),
|
|
):
|
|
token = authorization[7:]
|
|
user_id = id.split(":")[0]
|
|
hub_ = Hubs[hub]
|
|
if id not in hub_:
|
|
await websocket.close(code=1008)
|
|
return
|
|
try:
|
|
if (
|
|
user := await get_current_user(
|
|
SecurityScopes(scopes=["*"]), db, token_pw=token
|
|
)
|
|
) is None or str(user.id) != user_id:
|
|
await websocket.close(code=1008)
|
|
return
|
|
except HTTPException:
|
|
await websocket.close(code=1008)
|
|
return
|
|
await websocket.accept()
|
|
|
|
# handshake
|
|
handshake = await websocket.receive()
|
|
message = handshake.get("bytes") or handshake.get("text")
|
|
if not message:
|
|
await websocket.close(code=1008)
|
|
return
|
|
handshake_payload = json.loads(message[:-1])
|
|
error = ""
|
|
protocol = handshake_payload.get("protocol", "json")
|
|
|
|
client = None
|
|
try:
|
|
client = await hub_.add_client(
|
|
connection_id=user_id,
|
|
connection_token=id,
|
|
connection=websocket,
|
|
protocol=PROTOCOLS[protocol],
|
|
)
|
|
except KeyError:
|
|
error = f"Protocol '{protocol}' is not supported."
|
|
except TimeoutError:
|
|
error = f"Connection {id} has waited too long."
|
|
except ValueError as e:
|
|
error = str(e)
|
|
payload = {"error": error} if error else {}
|
|
# finish handshake
|
|
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
|
if error or not client:
|
|
await websocket.close(code=1008)
|
|
return
|
|
|
|
connected_clients = hub_.get_before_clients(user_id, id)
|
|
for connected_client in connected_clients:
|
|
await hub_.kick_client(connected_client)
|
|
|
|
await hub_.clean_state(client, False)
|
|
task = asyncio.create_task(hub_.on_connect(client))
|
|
hub_.tasks.add(task)
|
|
task.add_done_callback(hub_.tasks.discard)
|
|
await hub_._listen_client(client)
|
|
try:
|
|
await websocket.close()
|
|
except Exception:
|
|
...
|