chore(merge): merge branch 'main' into feat/multiplayer-api
This commit is contained in:
@@ -2,15 +2,13 @@ from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState, _by_index
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
@@ -23,7 +21,6 @@ from app.signalr.store import ResultStore
|
||||
from app.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
@@ -51,7 +48,7 @@ class Client:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self.procotol = protocol
|
||||
self.protocol = protocol
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
@@ -64,14 +61,14 @@ class Client:
|
||||
return int(self.connection_id)
|
||||
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
await self.connection.send_bytes(self.protocol.encode(packet))
|
||||
|
||||
async def receive_packets(self) -> list[Packet]:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return []
|
||||
return self.procotol.decode(d)
|
||||
return self.protocol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
@@ -265,14 +262,9 @@ class Hub[TState: UserState]:
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
elif inspect.isclass(param.annotation) and issubclass(
|
||||
param.annotation, Enum
|
||||
):
|
||||
call_params.append(_by_index(args.pop(0), param.annotation))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
call_params.append(
|
||||
client.protocol.validate_object(args.pop(0), param.annotation)
|
||||
)
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
|
||||
@@ -11,7 +11,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -31,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
) -> set[Coroutine]:
|
||||
if store is not None and not store.pushable:
|
||||
return set()
|
||||
data = store.to_dict() if store else None
|
||||
data = store.for_push if store else None
|
||||
return {
|
||||
self.broadcast_group_call(
|
||||
self.online_presence_watchers_group(),
|
||||
@@ -102,7 +101,9 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
"FriendPresenceUpdated",
|
||||
friend_id,
|
||||
friend_state.to_dict(),
|
||||
friend_state.for_push
|
||||
if friend_state.pushable
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -122,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
store.for_push,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
||||
async def UpdateActivity(
|
||||
self, client: Client, activity: UserActivity | None
|
||||
) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
activity = (
|
||||
TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
if activity_dict
|
||||
else None
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.user_activity = activity
|
||||
store.activity = activity
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
store.for_push,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -154,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
store,
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.pushable
|
||||
|
||||
@@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken
|
||||
from app.dependencies.database import engine
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
||||
from app.models.signalr import serialize_to_list
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
|
||||
from app.models.spectator_hub import (
|
||||
APIUser,
|
||||
FrameDataBundle,
|
||||
@@ -69,8 +68,8 @@ def save_replay(
|
||||
md5: str,
|
||||
username: str,
|
||||
score: Score,
|
||||
statistics: ScoreStatisticsInt,
|
||||
maximum_statistics: ScoreStatisticsInt,
|
||||
statistics: ScoreStatistics,
|
||||
maximum_statistics: ScoreStatistics,
|
||||
frames: list[LegacyReplayFrame],
|
||||
) -> None:
|
||||
data = bytearray()
|
||||
@@ -107,8 +106,8 @@ def save_replay(
|
||||
last_time = 0
|
||||
for frame in frames:
|
||||
frame_strs.append(
|
||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state}"
|
||||
f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
|
||||
f"|{frame.mouse_y or 0.0}|{frame.button_state}"
|
||||
)
|
||||
last_time = frame.time
|
||||
frame_strs.append("-12345|0|0|0")
|
||||
@@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
tasks = [
|
||||
self.call_noblock(
|
||||
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
||||
)
|
||||
self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
|
||||
for user_id, store in self.state.items()
|
||||
if store.state is not None
|
||||
]
|
||||
@@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state),
|
||||
state,
|
||||
)
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
@@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
state = self.get_or_create_state(client)
|
||||
if not state.score:
|
||||
return
|
||||
state.score.score_info.acc = frame_data.header.acc
|
||||
state.score.score_info.accuracy = frame_data.header.accuracy
|
||||
state.score.score_info.combo = frame_data.header.combo
|
||||
state.score.score_info.max_combo = frame_data.header.max_combo
|
||||
state.score.score_info.statistics = frame_data.header.statistics
|
||||
@@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self.group_id(user_id),
|
||||
"UserSentFrames",
|
||||
user_id,
|
||||
frame_data.model_dump(),
|
||||
frame_data,
|
||||
)
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
@@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
self.group_id(user_id),
|
||||
"UserFinishedPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state) if state else None,
|
||||
state,
|
||||
)
|
||||
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
@@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
serialize_to_list(target_store.state),
|
||||
target_store.state,
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.watched_user.add(target_id)
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
import datetime
|
||||
from enum import Enum, IntEnum
|
||||
import inspect
|
||||
import json
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from app.models.signalr import serialize_msgpack
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
||||
from app.utils import camel_to_snake, snake_to_camel
|
||||
|
||||
import msgpack_lazer_api as m
|
||||
from pydantic import BaseModel
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
@@ -75,8 +83,61 @@ class Protocol(TypingProtocol):
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes: ...
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type) -> Any: ...
|
||||
|
||||
|
||||
class MsgpackProtocol:
|
||||
@classmethod
|
||||
def serialize_msgpack(cls, v: Any) -> Any:
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_to_list(v)
|
||||
elif issubclass(typ, list):
|
||||
return [cls.serialize_msgpack(item) for item in v]
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
return [v, 0]
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
||||
for k, value in v.items()
|
||||
}
|
||||
elif issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v) if v in list_ else v.value
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
|
||||
values = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
values.append(cls.serialize_msgpack(v=getattr(value, field)))
|
||||
if issubclass(value.__class__, SignalRUnionMessage):
|
||||
return [value.__class__.union_type, values]
|
||||
else:
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
||||
if isinstance(v, list):
|
||||
d = {}
|
||||
for i, f in enumerate(typ.model_fields.items()):
|
||||
field, info = f
|
||||
if info.exclude:
|
||||
continue
|
||||
anno = info.annotation
|
||||
if anno is None:
|
||||
d[camel_to_snake(field)] = v[i]
|
||||
continue
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
return d
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
@@ -142,6 +203,49 @@ class MsgpackProtocol:
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type) -> Any:
|
||||
if issubclass(typ, BaseModel):
|
||||
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return v[0]
|
||||
elif isinstance(v, list):
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||
elif get_origin(typ) is dict:
|
||||
return {
|
||||
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
||||
v, get_args(typ)[1]
|
||||
)
|
||||
for k, v in v.items()
|
||||
}
|
||||
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||
args = get_args(typ)
|
||||
if len(args) == 2 and NoneType in args:
|
||||
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||
if len(non_none_args) == 1:
|
||||
if v is None:
|
||||
return None
|
||||
return cls.validate_object(v, non_none_args[0])
|
||||
|
||||
# suppose use `MessagePack-CSharp Union | None`
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
|
||||
raise ValueError(
|
||||
f"Cannot validate {v} to {typ}, "
|
||||
"only SignalRUnionMessage subclasses are supported"
|
||||
)
|
||||
union_type = v[0]
|
||||
for arg in args:
|
||||
assert issubclass(arg, SignalRUnionMessage)
|
||||
if arg.union_type == union_type:
|
||||
return cls.validate_object(v[1], arg)
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload = [packet.type.value, packet.header or {}]
|
||||
@@ -153,7 +257,9 @@ class MsgpackProtocol:
|
||||
]
|
||||
)
|
||||
if packet.arguments is not None:
|
||||
payload.append([serialize_msgpack(arg) for arg in packet.arguments])
|
||||
payload.append(
|
||||
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
|
||||
)
|
||||
if packet.stream_ids is not None:
|
||||
payload.append(packet.stream_ids)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
@@ -166,7 +272,9 @@ class MsgpackProtocol:
|
||||
[
|
||||
packet.invocation_id,
|
||||
result_kind,
|
||||
packet.error or packet.result or None,
|
||||
packet.error
|
||||
or MsgpackProtocol.serialize_msgpack(packet.result)
|
||||
or None,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, ClosePacket):
|
||||
@@ -183,6 +291,62 @@ class MsgpackProtocol:
|
||||
|
||||
|
||||
class JSONProtocol:
|
||||
@classmethod
|
||||
def serialize_to_json(cls, v: Any):
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_model(v)
|
||||
elif isinstance(v, dict):
|
||||
return {
|
||||
cls.serialize_to_json(k): cls.serialize_to_json(value)
|
||||
for k, value in v.items()
|
||||
}
|
||||
elif isinstance(v, list):
|
||||
return [cls.serialize_to_json(item) for item in v]
|
||||
elif isinstance(v, datetime.datetime):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, Enum):
|
||||
return v.value
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
|
||||
d = {}
|
||||
for field, info in v.__class__.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
|
||||
cls.serialize_to_json(getattr(v, field))
|
||||
)
|
||||
if issubclass(v.__class__, SignalRUnionMessage):
|
||||
return {
|
||||
"$dtype": v.__class__.__name__,
|
||||
"$value": d,
|
||||
}
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def process_object(
|
||||
v: Any, typ: type[BaseModel], from_union: bool = False
|
||||
) -> dict[str, Any]:
|
||||
d = {}
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next(
|
||||
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||
)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
value = v.get(snake_to_camel(field, not from_union))
|
||||
anno = typ.model_fields[field].annotation
|
||||
if anno is None:
|
||||
d[field] = value
|
||||
continue
|
||||
d[field] = JSONProtocol.validate_object(value, anno)
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||
@@ -227,6 +391,52 @@ class JSONProtocol:
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
|
||||
if issubclass(typ, BaseModel):
|
||||
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
elif isinstance(v, list):
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||
elif get_origin(typ) is dict:
|
||||
return {
|
||||
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
||||
v, get_args(typ)[1]
|
||||
)
|
||||
for k, v in v.items()
|
||||
}
|
||||
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||
args = get_args(typ)
|
||||
if len(args) == 2 and NoneType in args:
|
||||
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||
if len(non_none_args) == 1:
|
||||
if v is None:
|
||||
return None
|
||||
return cls.validate_object(v, non_none_args[0])
|
||||
|
||||
# suppose use `MessagePack-CSharp Union | None`
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(
|
||||
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot validate {v} to {typ}, "
|
||||
"only SignalRUnionMessage subclasses are supported"
|
||||
)
|
||||
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
|
||||
union_type = v["$dtype"]
|
||||
for arg in args:
|
||||
assert issubclass(arg, SignalRUnionMessage)
|
||||
if arg.__name__ == union_type:
|
||||
return cls.validate_object(v["$value"], arg, True)
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload: dict[str, Any] = {
|
||||
@@ -243,7 +453,9 @@ class JSONProtocol:
|
||||
if packet.invocation_id is not None:
|
||||
payload["invocationId"] = packet.invocation_id
|
||||
if packet.arguments is not None:
|
||||
payload["arguments"] = packet.arguments
|
||||
payload["arguments"] = [
|
||||
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
|
||||
]
|
||||
if packet.stream_ids is not None:
|
||||
payload["streamIds"] = packet.stream_ids
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
@@ -255,7 +467,7 @@ class JSONProtocol:
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
if packet.result is not None:
|
||||
payload["result"] = packet.result
|
||||
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
elif isinstance(packet, ClosePacket):
|
||||
|
||||
Reference in New Issue
Block a user