Reapply "Merge branch 'main' of https://github.com/GooGuTeam/osu_lazer_api"
This reverts commit 68701dbb1d.
This commit is contained in:
@@ -3,3 +3,4 @@ from __future__ import annotations
|
||||
from . import me # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
from .signalr import signalr_router as signalr_router
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.auth import (
|
||||
)
|
||||
from app.config import settings
|
||||
from app.dependencies import get_db
|
||||
from app.models import TokenResponse
|
||||
from app.models.oauth import TokenResponse
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models import (
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/me/", response_model=ApiUser)
|
||||
async def get_user_info_default(
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
|
||||
1
app/router/signalr/__init__.py
Normal file
1
app/router/signalr/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .router import router as signalr_router
|
||||
6
app/router/signalr/exception.py
Normal file
6
app/router/signalr/exception.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class SignalRException(Exception):
|
||||
pass
|
||||
|
||||
class InvokeException(SignalRException):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
15
app/router/signalr/hub/__init__.py
Normal file
15
app/router/signalr/hub/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
from .metadata import MetadataHub
|
||||
from .multiplayer import MultiplayerHub
|
||||
from .spectator import SpectatorHub
|
||||
|
||||
SpectatorHubs = SpectatorHub()
|
||||
MultiplayerHubs = MultiplayerHub()
|
||||
MetadataHubs = MetadataHub()
|
||||
Hubs: dict[str, Hub] = {
|
||||
"spectator": SpectatorHubs,
|
||||
"multiplayer": MultiplayerHubs,
|
||||
"metadata": MetadataHubs,
|
||||
}
|
||||
211
app/router/signalr/hub/hub.py
Normal file
211
app/router/signalr/hub/hub.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.router.signalr.exception import InvokeException
|
||||
from app.router.signalr.packet import (
|
||||
PacketType,
|
||||
ResultKind,
|
||||
encode_varint,
|
||||
parse_packet,
|
||||
)
|
||||
from app.router.signalr.store import ResultStore
|
||||
from app.router.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
import msgpack
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
) -> None:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
|
||||
async def send_packet(self, type: PacketType, packet: list[Any]):
|
||||
packet.insert(0, type.value)
|
||||
payload = msgpack.packb(packet)
|
||||
length = encode_varint(len(payload))
|
||||
await self.connection.send_bytes(length + payload)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PacketType.PING, [])
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error in ping task for {self.connection_id}: {e}")
|
||||
break
|
||||
|
||||
|
||||
class Hub:
|
||||
def __init__(self) -> None:
|
||||
self.clients: dict[str, Client] = {}
|
||||
self.waited_clients: dict[str, int] = {}
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
|
||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
|
||||
def add_client(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
) -> Client:
|
||||
if connection_token in self.clients:
|
||||
raise ValueError(
|
||||
f"Client with connection token {connection_token} already exists."
|
||||
)
|
||||
if connection_token in self.waited_clients:
|
||||
if (
|
||||
self.waited_clients[connection_token]
|
||||
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
client = Client(connection_id, connection_token, connection)
|
||||
self.clients[connection_token] = client
|
||||
task = asyncio.create_task(client._ping())
|
||||
self.tasks.add(task)
|
||||
client._ping_task = task
|
||||
return client
|
||||
|
||||
async def remove_client(self, connection_id: str) -> None:
|
||||
if client := self.clients.get(connection_id):
|
||||
del self.clients[connection_id]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
await client.connection.close()
|
||||
|
||||
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
|
||||
await client.send_packet(type, packet)
|
||||
|
||||
async def _listen_client(self, client: Client) -> None:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
message = await client.connection.receive_bytes()
|
||||
packet_type, packet_data = parse_packet(message)
|
||||
task = asyncio.create_task(
|
||||
self._handle_packet(client, packet_type, packet_data)
|
||||
)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
if e.code == 1005:
|
||||
continue
|
||||
print(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
jump = True
|
||||
except Exception as e:
|
||||
print(f"Error in client {client.connection_id}: {e}")
|
||||
jump = True
|
||||
await self.remove_client(client.connection_id)
|
||||
|
||||
async def _handle_packet(
|
||||
self, client: Client, type: PacketType, packet: list[Any]
|
||||
) -> None:
|
||||
match type:
|
||||
case PacketType.PING:
|
||||
...
|
||||
case PacketType.INVOCATION:
|
||||
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
|
||||
target: str = packet[2]
|
||||
args: list[Any] | None = packet[3]
|
||||
if args is None:
|
||||
args = []
|
||||
streams: list[str] | None = packet[4] # TODO: stream support
|
||||
code = ResultKind.VOID
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, target, args)
|
||||
if result is not None:
|
||||
code = ResultKind.HAS_VALUE
|
||||
except InvokeException as e:
|
||||
code = ResultKind.ERROR
|
||||
result = e.message
|
||||
|
||||
except Exception as e:
|
||||
code = ResultKind.ERROR
|
||||
result = str(e)
|
||||
|
||||
packet = [
|
||||
{}, # header
|
||||
invocation_id,
|
||||
code.value,
|
||||
]
|
||||
if result is not None:
|
||||
packet.append(result)
|
||||
if invocation_id is not None:
|
||||
await client.send_packet(
|
||||
PacketType.COMPLETION,
|
||||
packet,
|
||||
)
|
||||
case PacketType.COMPLETION:
|
||||
invocation_id: str = packet[1]
|
||||
code: ResultKind = ResultKind(packet[2])
|
||||
result: Any = packet[3] if len(packet) > 3 else None
|
||||
client._store.add_result(invocation_id, code, result)
|
||||
|
||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||
method_ = getattr(self, method, None)
|
||||
call_params = []
|
||||
if not method_:
|
||||
raise InvokeException(f"Method '{method}' not found in hub.")
|
||||
signature = get_signature(method_)
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
invocation_id = client._store.get_invocation_id()
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
invocation_id,
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
)
|
||||
r = await client._store.fetch(invocation_id, None)
|
||||
if r[0] == ResultKind.HAS_VALUE:
|
||||
return r[1]
|
||||
if r[0] == ResultKind.ERROR:
|
||||
raise InvokeException(r[1])
|
||||
return None
|
||||
|
||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
None, # invocation_id
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
)
|
||||
return None
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self.clients or item in self.waited_clients
|
||||
4
app/router/signalr/hub/metadata.py
Normal file
4
app/router/signalr/hub/metadata.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
class MetadataHub(Hub): ...
|
||||
4
app/router/signalr/hub/multiplayer.py
Normal file
4
app/router/signalr/hub/multiplayer.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
class MultiplayerHub(Hub): ...
|
||||
17
app/router/signalr/hub/spectator.py
Normal file
17
app/router/signalr/hub/spectator.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.spectator_hub import FrameDataBundle, SpectatorState
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
|
||||
class SpectatorHub(Hub):
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def SendFrameData(
|
||||
self, client: Client, frame_data: FrameDataBundle
|
||||
) -> None:
|
||||
...
|
||||
56
app/router/signalr/packet.py
Normal file
56
app/router/signalr/packet.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
|
||||
class PacketType(IntEnum):
|
||||
INVOCATION = 1
|
||||
STREAM_ITEM = 2
|
||||
COMPLETION = 3
|
||||
STREAM_INVOCATION = 4
|
||||
CANCEL_INVOCATION = 5
|
||||
PING = 6
|
||||
CLOSE = 7
|
||||
|
||||
class ResultKind(IntEnum):
|
||||
ERROR = 1
|
||||
VOID = 2
|
||||
HAS_VALUE = 3
|
||||
|
||||
|
||||
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
|
||||
length, offset = decode_varint(data)
|
||||
message_data = data[offset : offset + length]
|
||||
unpacked = msgpack.unpackb(message_data, raw=False)
|
||||
return PacketType(unpacked[0]), unpacked[1:]
|
||||
|
||||
|
||||
def encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
||||
result = 0
|
||||
shift = 0
|
||||
pos = offset
|
||||
|
||||
while pos < len(data):
|
||||
byte = data[pos]
|
||||
result |= (byte & 0x7F) << shift
|
||||
pos += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
|
||||
return result, pos
|
||||
93
app/router/signalr/router.py
Normal file
93
app/router/signalr/router.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import info
|
||||
import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token, security
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
from app.router.signalr.packet import SEP
|
||||
|
||||
from .hub import Hubs
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
|
||||
async def negotiate(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
negotiate_version: int = Query(1, alias="negotiateVersion"),
|
||||
user: DBUser = Depends(get_current_user),
|
||||
):
|
||||
connectionId = str(user.id)
|
||||
connectionToken = f"{connectionId}:{uuid.uuid4()}"
|
||||
Hubs[hub].add_waited_client(
|
||||
connection_token=connectionToken,
|
||||
timestamp=int(time.time()),
|
||||
)
|
||||
return NegotiateResponse(
|
||||
connectionId=connectionId,
|
||||
connectionToken=connectionToken,
|
||||
negotiateVersion=negotiate_version,
|
||||
availableTransports=[Transport(transport="WebSockets")],
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/{hub}")
|
||||
async def connect(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
hub_ = Hubs[hub]
|
||||
if id not in hub_:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
if (user := await get_current_user_by_token(token, db)) is None or str(
|
||||
user.id
|
||||
) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await websocket.accept()
|
||||
|
||||
# handshake
|
||||
handshake = await websocket.receive_bytes()
|
||||
handshake_payload = json.loads(handshake[:-1])
|
||||
error = ""
|
||||
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
|
||||
handshake_payload.get("version")
|
||||
) != 1:
|
||||
error = f"Requested protocol '{protocol}' is not available."
|
||||
|
||||
client = None
|
||||
try:
|
||||
client = hub_.add_client(
|
||||
connection_id=user_id,
|
||||
connection_token=id,
|
||||
connection=websocket,
|
||||
)
|
||||
except TimeoutError:
|
||||
error = f"Connection {id} has waited too long."
|
||||
except ValueError as e:
|
||||
error = str(e)
|
||||
payload = {"error": error} if error else {}
|
||||
|
||||
# finish handshake
|
||||
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
||||
if error or not client:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await hub_._listen_client(client)
|
||||
45
app/router/signalr/store.py
Normal file
45
app/router/signalr/store.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.router.signalr.packet import ResultKind
|
||||
|
||||
|
||||
class ResultStore:
|
||||
def __init__(self) -> None:
|
||||
self._seq: int = 1
|
||||
self._futures: dict[str, asyncio.Future] = {}
|
||||
|
||||
@property
|
||||
def current_invocation_id(self) -> int:
|
||||
return self._seq
|
||||
|
||||
def get_invocation_id(self) -> str:
|
||||
s = self._seq
|
||||
self._seq = (self._seq + 1) % sys.maxsize
|
||||
return str(s)
|
||||
|
||||
def add_result(
|
||||
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
|
||||
) -> None:
|
||||
if isinstance(invocation_id, str) and invocation_id.isdecimal():
|
||||
if future := self._futures.get(invocation_id):
|
||||
future.set_result((type, result))
|
||||
|
||||
async def fetch(
|
||||
self,
|
||||
invocation_id: str,
|
||||
timeout: float | None, # noqa: ASYNC109
|
||||
) -> (
|
||||
tuple[Literal[ResultKind.ERROR], str]
|
||||
| tuple[Literal[ResultKind.VOID], None]
|
||||
| tuple[Literal[ResultKind.HAS_VALUE], Any]
|
||||
):
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._futures[invocation_id] = future
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout)
|
||||
finally:
|
||||
del self._futures[invocation_id]
|
||||
45
app/router/signalr/utils.py
Normal file
45
app/router/signalr/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, ForwardRef, cast
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||
def evaluate_forwardref(
|
||||
type_: ForwardRef,
|
||||
globalns: Any,
|
||||
localns: Any,
|
||||
) -> Any:
|
||||
# Even though it is the right signature for python 3.9,
|
||||
# mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of
|
||||
# "ForwardRef"` hence the cast...
|
||||
return cast(Any, type_)._evaluate(
|
||||
globalns,
|
||||
localns,
|
||||
set(),
|
||||
)
|
||||
|
||||
|
||||
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
try:
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
except Exception:
|
||||
return inspect.Parameter.empty
|
||||
return annotation
|
||||
|
||||
|
||||
def get_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_annotation(param, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return inspect.Signature(typed_params)
|
||||
Reference in New Issue
Block a user