chore(merge): merge branch 'main' into feat/multiplayer-api

This commit is contained in:
MingxuanGame
2025-08-03 09:50:53 +00:00
13 changed files with 434 additions and 325 deletions

View File

@@ -2,15 +2,13 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from enum import Enum
import inspect
import time
from typing import Any
from app.config import settings
from app.exception import InvokeException
from app.log import logger
from app.models.signalr import UserState, _by_index
from app.models.signalr import UserState
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
@@ -23,7 +21,6 @@ from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
@@ -51,7 +48,7 @@ class Client:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self.procotol = protocol
self.protocol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
@@ -64,14 +61,14 @@ class Client:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
await self.connection.send_bytes(self.protocol.encode(packet))
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return []
return self.procotol.decode(d)
return self.protocol.decode(d)
async def _ping(self):
while True:
@@ -265,14 +262,9 @@ class Hub[TState: UserState]:
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
elif inspect.isclass(param.annotation) and issubclass(
param.annotation, Enum
):
call_params.append(_by_index(args.pop(0), param.annotation))
else:
call_params.append(args.pop(0))
call_params.append(
client.protocol.validate_object(args.pop(0), param.annotation)
)
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:

View File

@@ -11,7 +11,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv
from .hub import Client, Hub
from pydantic import TypeAdapter
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -31,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
) -> set[Coroutine]:
if store is not None and not store.pushable:
return set()
data = store.to_dict() if store else None
data = store.for_push if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
@@ -102,7 +101,9 @@ class MetadataHub(Hub[MetadataClientState]):
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
friend_state.to_dict(),
friend_state.for_push
if friend_state.pushable
else None,
)
)
await asyncio.gather(*tasks)
@@ -122,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store.for_push,
)
)
await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
async def UpdateActivity(
self, client: Client, activity: UserActivity | None
) -> None:
user_id = int(client.connection_id)
activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.get_or_create_state(client)
store.user_activity = activity
store.activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store.for_push,
)
)
await asyncio.gather(*tasks)
@@ -154,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store,
)
for user_id, store in self.state.items()
if store.pushable

View File

@@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
from app.models.signalr import serialize_to_list
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
@@ -69,8 +68,8 @@ def save_replay(
md5: str,
username: str,
score: Score,
statistics: ScoreStatisticsInt,
maximum_statistics: ScoreStatisticsInt,
statistics: ScoreStatistics,
maximum_statistics: ScoreStatistics,
frames: list[LegacyReplayFrame],
) -> None:
data = bytearray()
@@ -107,8 +106,8 @@ def save_replay(
last_time = 0
for frame in frames:
frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state}"
f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
f"|{frame.mouse_y or 0.0}|{frame.button_state}"
)
last_time = frame.time
frame_strs.append("-12345|0|0|0")
@@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
for user_id, store in self.state.items()
if store.state is not None
]
@@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserBeganPlaying",
user_id,
serialize_to_list(state),
state,
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
@@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
state = self.get_or_create_state(client)
if not state.score:
return
state.score.score_info.acc = frame_data.header.acc
state.score.score_info.accuracy = frame_data.header.accuracy
state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics
@@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserSentFrames",
user_id,
frame_data.model_dump(),
frame_data,
)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
@@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
serialize_to_list(state) if state else None,
state,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
@@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
target_store.state,
)
store = self.get_or_create_state(client)
store.watched_user.add(target_id)

View File

@@ -1,16 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
import datetime
from enum import Enum, IntEnum
import inspect
import json
from types import NoneType, UnionType
from typing import (
Any,
Protocol as TypingProtocol,
Union,
get_args,
get_origin,
)
from app.models.signalr import serialize_msgpack
from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel
import msgpack_lazer_api as m
from pydantic import BaseModel
SEP = b"\x1e"
@@ -75,8 +83,61 @@ class Protocol(TypingProtocol):
@staticmethod
def encode(packet: Packet) -> bytes: ...
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any: ...
class MsgpackProtocol:
@classmethod
def serialize_msgpack(cls, v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_to_list(v)
elif issubclass(typ, list):
return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif isinstance(v, dict):
return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
for k, value in v.items()
}
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
@classmethod
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
values = []
for field, info in value.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.member_ignore:
continue
values.append(cls.serialize_msgpack(v=getattr(value, field)))
if issubclass(value.__class__, SignalRUnionMessage):
return [value.__class__.union_type, values]
else:
return values
@staticmethod
def process_object(v: Any, typ: type[BaseModel]) -> Any:
if isinstance(v, list):
d = {}
for i, f in enumerate(typ.model_fields.items()):
field, info = f
if info.exclude:
continue
anno = info.annotation
if anno is None:
d[camel_to_snake(field)] = v[i]
continue
d[field] = MsgpackProtocol.validate_object(v[i], anno)
return d
return v
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
@@ -142,6 +203,49 @@ class MsgpackProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
union_type = v[0]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.union_type == union_type:
return cls.validate_object(v[1], arg)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
@@ -153,7 +257,9 @@ class MsgpackProtocol:
]
)
if packet.arguments is not None:
payload.append([serialize_msgpack(arg) for arg in packet.arguments])
payload.append(
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
)
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
@@ -166,7 +272,9 @@ class MsgpackProtocol:
[
packet.invocation_id,
result_kind,
packet.error or packet.result or None,
packet.error
or MsgpackProtocol.serialize_msgpack(packet.result)
or None,
]
)
elif isinstance(packet, ClosePacket):
@@ -183,6 +291,62 @@ class MsgpackProtocol:
class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v)
elif isinstance(v, dict):
return {
cls.serialize_to_json(k): cls.serialize_to_json(value)
for k, value in v.items()
}
elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime):
return v.isoformat()
elif isinstance(v, Enum):
return v.value
return v
@classmethod
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
d = {}
for field, info in v.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
cls.serialize_to_json(getattr(v, field))
)
if issubclass(v.__class__, SignalRUnionMessage):
return {
"$dtype": v.__class__.__name__,
"$value": d,
}
return d
@staticmethod
def process_object(
v: Any, typ: type[BaseModel], from_union: bool = False
) -> dict[str, Any]:
d = {}
for field, info in typ.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
value = v.get(snake_to_camel(field, not from_union))
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
continue
d[field] = JSONProtocol.validate_object(value, anno)
return d
@staticmethod
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
@@ -227,6 +391,52 @@ class JSONProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v)
elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
union_type = v["$dtype"]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.__name__ == union_type:
return cls.validate_object(v["$value"], arg, True)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
@@ -243,7 +453,9 @@ class JSONProtocol:
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = packet.arguments
payload["arguments"] = [
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
]
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
@@ -255,7 +467,7 @@ class JSONProtocol:
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = packet.result
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):