feat(signalr): support json & msgpack protocol for all hubs

This commit is contained in:
MingxuanGame
2025-08-03 09:45:04 +00:00
parent 0f1a57afba
commit 9f7ab81213
13 changed files with 432 additions and 307 deletions

View File

@@ -1,114 +1,85 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any, Literal
from typing import ClassVar, Literal
from app.models.signalr import UserState
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
class _UserActivity(BaseModel):
model_config = ConfigDict(serialize_by_alias=True)
type: Literal[
"ChoosingBeatmap",
"InSoloGame",
"WatchingReplay",
"SpectatingUser",
"SearchingForLobby",
"InLobby",
"InMultiplayerGame",
"SpectatingMultiplayerGame",
"InPlaylistGame",
"EditingBeatmap",
"ModdingBeatmap",
"TestingBeatmap",
"InDailyChallengeLobby",
"PlayingDailyChallenge",
] = Field(alias="$dtype")
value: Any | None = Field(alias="$value")
class _UserActivity(SignalRUnionMessage): ...
class ChoosingBeatmap(_UserActivity):
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
class InGameValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
ruleset_id: int = Field(alias="RulesetID")
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
union_type: ClassVar[Literal[11]] = 11
class _InGame(_UserActivity):
value: InGameValue = Field(alias="$value")
beatmap_id: int
beatmap_display_title: str
ruleset_id: int
ruleset_playing_verb: str
class InSoloGame(_InGame):
type: Literal["InSoloGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[12]] = 12
class InMultiplayerGame(_InGame):
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[23]] = 23
class SpectatingMultiplayerGame(_InGame):
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[24]] = 24
class InPlaylistGame(_InGame):
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
union_type: ClassVar[Literal[31]] = 31
class EditingBeatmapValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class PlayingDailyChallenge(_InGame):
union_type: ClassVar[Literal[52]] = 52
class EditingBeatmap(_UserActivity):
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
value: EditingBeatmapValue = Field(alias="$value")
union_type: ClassVar[Literal[41]] = 41
beatmap_id: int
beatmap_display_title: str
class TestingBeatmap(_UserActivity):
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
class TestingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[43]] = 43
class ModdingBeatmap(_UserActivity):
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
class WatchingReplayValue(BaseModel):
score_id: int = Field(alias="ScoreID")
player_name: str = Field(alias="PlayerName")
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class ModdingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[42]] = 42
class WatchingReplay(_UserActivity):
type: Literal["WatchingReplay"] = Field(alias="$dtype")
value: int | None = Field(alias="$value") # Replay ID
union_type: ClassVar[Literal[13]] = 13
score_id: int
player_name: str
beatmap_id: int
beatmap_display_title: str
class SpectatingUser(WatchingReplay):
type: Literal["SpectatingUser"] = Field(alias="$dtype")
union_type: ClassVar[Literal[14]] = 14
class SearchingForLobby(_UserActivity):
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
class InLobbyValue(BaseModel):
room_id: int = Field(alias="RoomID")
room_name: str = Field(alias="RoomName")
union_type: ClassVar[Literal[21]] = 21
class InLobby(_UserActivity):
type: Literal["InLobby"] = "InLobby"
union_type: ClassVar[Literal[22]] = 22
room_id: int
room_name: str
class InDailyChallengeLobby(_UserActivity):
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
union_type: ClassVar[Literal[51]] = 51
UserActivity = (
@@ -128,23 +99,28 @@ UserActivity = (
)
class MetadataClientState(UserState):
user_activity: UserActivity | None = None
status: OnlineStatus | None = None
def to_dict(self) -> dict[str, Any] | None:
if self.status is None or self.status == OnlineStatus.OFFLINE:
return None
dumped = self.model_dump(by_alias=True, exclude_none=True)
return {
"Activity": dumped.get("user_activity"),
"Status": dumped.get("status"),
}
class UserPresence(BaseModel):
activity: UserActivity | None = Field(
default=None, metadata=SignalRMeta(use_upper_case=True)
)
status: OnlineStatus | None = Field(
default=None, metadata=SignalRMeta(use_upper_case=True)
)
@property
def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE
@property
def for_push(self) -> "UserPresence | None":
return UserPresence(
activity=self.activity,
status=self.status,
)
class MetadataClientState(UserPresence, UserState): ...
class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from enum import Enum, IntEnum
from enum import Enum
from typing import Literal, TypedDict
from .mods import API_MODS, APIMod, init_mods
@@ -93,43 +93,6 @@ class HitResult(str, Enum):
)
class HitResultInt(IntEnum):
PERFECT = 0
GREAT = 1
GOOD = 2
OK = 3
MEH = 4
MISS = 5
LARGE_TICK_HIT = 6
SMALL_TICK_HIT = 7
SLIDER_TAIL_HIT = 8
LARGE_BONUS = 9
SMALL_BONUS = 10
LARGE_TICK_MISS = 11
SMALL_TICK_MISS = 12
IGNORE_HIT = 13
IGNORE_MISS = 14
NONE = 15
COMBO_BREAK = 16
LEGACY_COMBO_INCREASE = 99
def is_hit(self) -> bool:
return self not in (
HitResultInt.NONE,
HitResultInt.IGNORE_MISS,
HitResultInt.COMBO_BREAK,
HitResultInt.LARGE_TICK_MISS,
HitResultInt.SMALL_TICK_MISS,
HitResultInt.MISS,
)
class LeaderboardType(Enum):
GLOBAL = "global"
FRIENDS = "friend"
@@ -138,7 +101,6 @@ class LeaderboardType(Enum):
ScoreStatistics = dict[HitResult, int]
ScoreStatisticsInt = dict[HitResultInt, int]
class SoloScoreSubmissionInfo(BaseModel):
@@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel):
class LegacyReplaySoloScoreInfo(TypedDict):
online_id: int
mods: list[APIMod]
statistics: ScoreStatisticsInt
maximum_statistics: ScoreStatisticsInt
statistics: ScoreStatistics
maximum_statistics: ScoreStatistics
client_version: str
rank: Rank
user_id: int

View File

@@ -1,41 +1,21 @@
from __future__ import annotations
import datetime
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Any, ClassVar
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
TypeAdapter,
model_serializer,
model_validator,
)
def serialize_msgpack(v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return serialize_to_list(v)
elif issubclass(typ, list):
return TypeAdapter(
typ, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
def serialize_to_list(value: BaseModel) -> list[Any]:
data = []
for field, info in value.__class__.model_fields.items():
data.append(serialize_msgpack(v=getattr(value, field)))
return data
@dataclass
class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_upper_case: bool = False # use upper CamelCase for field names
def _by_index(v: Any, class_: type[Enum]):
@@ -54,37 +34,8 @@ def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator:
return BeforeValidator(lambda v: _by_index(v, enum_class))
def msgpack_union(v):
data = v[1]
data.append(v[0])
return data
def msgpack_union_dump(v: BaseModel) -> list[Any]:
_type = getattr(v, "type", None)
if _type is None:
raise ValueError(
f"Model {v.__class__.__name__} does not have a '_type' attribute"
)
return [_type, serialize_to_list(v)]
class MessagePackArrayModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="before")
@classmethod
def unpack(cls, v: Any) -> Any:
if isinstance(v, list):
fields = list(cls.model_fields.keys())
if len(v) != len(fields):
raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
return dict(zip(fields, v))
return v
@model_serializer
def serialize(self) -> list[Any]:
return serialize_to_list(self)
class SignalRUnionMessage(BaseModel):
union_type: ClassVar[int]
class Transport(BaseModel):

View File

@@ -5,14 +5,14 @@ from enum import IntEnum
from typing import Any
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from .score import (
ScoreStatisticsInt,
ScoreStatistics,
)
from .signalr import MessagePackArrayModel, UserState
from .signalr import SignalRMeta, UserState
from msgpack_lazer_api import APIMod
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
class SpectatedUserState(IntEnum):
@@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum):
Quit = 5
class SpectatorState(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class SpectatorState(BaseModel):
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
@@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel):
)
class ScoreProcessorStatistics(MessagePackArrayModel):
base_score: int
maximum_base_score: int
class ScoreProcessorStatistics(BaseModel):
base_score: float
maximum_base_score: float
accuracy_judgement_count: int
combo_portion: float
bouns_portion: float
bonus_portion: float
class FrameHeader(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class FrameHeader(BaseModel):
total_score: int
acc: float
accuracy: float
combo: int
max_combo: int
statistics: ScoreStatisticsInt = Field(default_factory=dict)
statistics: ScoreStatistics = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel):
# SMOKE = 16
class LegacyReplayFrame(MessagePackArrayModel):
class LegacyReplayFrame(BaseModel):
time: float # from ReplayFrame,the parent of LegacyReplayFrame
x: float | None = None
y: float | None = None
mouse_x: float | None = None
mouse_y: float | None = None
button_state: int
header: FrameHeader | None = Field(
default=None, metadata=[SignalRMeta(member_ignore=True)]
)
class FrameDataBundle(MessagePackArrayModel):
class FrameDataBundle(BaseModel):
header: FrameHeader
frames: list[LegacyReplayFrame]
@@ -106,18 +106,16 @@ class APIUser(BaseModel):
class ScoreInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mods: list[APIMod]
user: APIUser
ruleset: int
maximum_statistics: ScoreStatisticsInt
maximum_statistics: ScoreStatistics
id: int | None = None
total_score: int | None = None
acc: float | None = None
accuracy: float | None = None
max_combo: int | None = None
combo: int | None = None
statistics: ScoreStatisticsInt = Field(default_factory=dict)
statistics: ScoreStatistics = Field(default_factory=dict)
class StoreScore(BaseModel):

View File

@@ -2,14 +2,12 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from enum import Enum
import inspect
import time
from typing import Any
from app.config import settings
from app.log import logger
from app.models.signalr import UserState, _by_index
from app.models.signalr import UserState
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
@@ -23,7 +21,6 @@ from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
@@ -51,7 +48,7 @@ class Client:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self.procotol = protocol
self.protocol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
@@ -64,14 +61,14 @@ class Client:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
await self.connection.send_bytes(self.protocol.encode(packet))
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return []
return self.procotol.decode(d)
return self.protocol.decode(d)
async def _ping(self):
while True:
@@ -265,14 +262,9 @@ class Hub[TState: UserState]:
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
elif inspect.isclass(param.annotation) and issubclass(
param.annotation, Enum
):
call_params.append(_by_index(args.pop(0), param.annotation))
else:
call_params.append(args.pop(0))
call_params.append(
client.protocol.validate_object(args.pop(0), param.annotation)
)
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:

View File

@@ -12,7 +12,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv
from .hub import Client, Hub
from pydantic import TypeAdapter
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -32,7 +31,7 @@ class MetadataHub(Hub[MetadataClientState]):
) -> set[Coroutine]:
if store is not None and not store.pushable:
return set()
data = store.to_dict() if store else None
data = store.for_push if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
@@ -103,7 +102,7 @@ class MetadataHub(Hub[MetadataClientState]):
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
friend_state.to_dict(),
friend_state if friend_state.pushable else None,
)
)
await asyncio.gather(*tasks)
@@ -123,27 +122,24 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store.for_push,
)
)
await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
async def UpdateActivity(
self, client: Client, activity: UserActivity | None
) -> None:
user_id = int(client.connection_id)
activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.get_or_create_state(client)
store.user_activity = activity
store.activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store.for_push,
)
)
await asyncio.gather(*tasks)
@@ -155,7 +151,7 @@ class MetadataHub(Hub[MetadataClientState]):
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
store,
)
for user_id, store in self.state.items()
if store.pushable

View File

@@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
from app.models.signalr import serialize_to_list
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
@@ -69,8 +68,8 @@ def save_replay(
md5: str,
username: str,
score: Score,
statistics: ScoreStatisticsInt,
maximum_statistics: ScoreStatisticsInt,
statistics: ScoreStatistics,
maximum_statistics: ScoreStatistics,
frames: list[LegacyReplayFrame],
) -> None:
data = bytearray()
@@ -107,8 +106,8 @@ def save_replay(
last_time = 0
for frame in frames:
frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state}"
f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
f"|{frame.mouse_y or 0.0}|{frame.button_state}"
)
last_time = frame.time
frame_strs.append("-12345|0|0|0")
@@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
for user_id, store in self.state.items()
if store.state is not None
]
@@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserBeganPlaying",
user_id,
serialize_to_list(state),
state,
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
@@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
state = self.get_or_create_state(client)
if not state.score:
return
state.score.score_info.acc = frame_data.header.acc
state.score.score_info.accuracy = frame_data.header.accuracy
state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics
@@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserSentFrames",
user_id,
frame_data.model_dump(),
frame_data,
)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
@@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
serialize_to_list(state) if state else None,
state,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
@@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
target_store.state,
)
store = self.get_or_create_state(client)
store.watched_user.add(target_id)

View File

@@ -1,16 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
import datetime
from enum import Enum, IntEnum
import inspect
import json
from types import NoneType, UnionType
from typing import (
Any,
Protocol as TypingProtocol,
Union,
get_args,
get_origin,
)
from app.models.signalr import serialize_msgpack
from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel
import msgpack_lazer_api as m
from pydantic import BaseModel
SEP = b"\x1e"
@@ -75,8 +83,61 @@ class Protocol(TypingProtocol):
@staticmethod
def encode(packet: Packet) -> bytes: ...
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any: ...
class MsgpackProtocol:
@classmethod
def serialize_msgpack(cls, v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_to_list(v)
elif issubclass(typ, list):
return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif isinstance(v, dict):
return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
for k, value in v.items()
}
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
@classmethod
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
values = []
for field, info in value.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.member_ignore:
continue
values.append(cls.serialize_msgpack(v=getattr(value, field)))
if issubclass(value.__class__, SignalRUnionMessage):
return [value.__class__.union_type, values]
else:
return values
@staticmethod
def process_object(v: Any, typ: type[BaseModel]) -> Any:
if isinstance(v, list):
d = {}
for i, f in enumerate(typ.model_fields.items()):
field, info = f
if info.exclude:
continue
anno = info.annotation
if anno is None:
d[camel_to_snake(field)] = v[i]
continue
d[field] = MsgpackProtocol.validate_object(v[i], anno)
return d
return v
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
@@ -142,6 +203,49 @@ class MsgpackProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
union_type = v[0]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.union_type == union_type:
return cls.validate_object(v[1], arg)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
@@ -153,7 +257,9 @@ class MsgpackProtocol:
]
)
if packet.arguments is not None:
payload.append([serialize_msgpack(arg) for arg in packet.arguments])
payload.append(
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
)
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
@@ -166,7 +272,9 @@ class MsgpackProtocol:
[
packet.invocation_id,
result_kind,
packet.error or packet.result or None,
packet.error
or MsgpackProtocol.serialize_msgpack(packet.result)
or None,
]
)
elif isinstance(packet, ClosePacket):
@@ -183,6 +291,62 @@ class MsgpackProtocol:
class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v)
elif isinstance(v, dict):
return {
cls.serialize_to_json(k): cls.serialize_to_json(value)
for k, value in v.items()
}
elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime):
return v.isoformat()
elif isinstance(v, Enum):
return v.value
return v
@classmethod
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
d = {}
for field, info in v.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
cls.serialize_to_json(getattr(v, field))
)
if issubclass(v.__class__, SignalRUnionMessage):
return {
"$dtype": v.__class__.__name__,
"$value": d,
}
return d
@staticmethod
def process_object(
v: Any, typ: type[BaseModel], from_union: bool = False
) -> dict[str, Any]:
d = {}
for field, info in typ.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
value = v.get(snake_to_camel(field, not from_union))
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
continue
d[field] = JSONProtocol.validate_object(value, anno)
return d
@staticmethod
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
@@ -227,6 +391,52 @@ class JSONProtocol:
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v)
elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
union_type = v["$dtype"]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.__name__ == union_type:
return cls.validate_object(v["$value"], arg, True)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
@@ -243,7 +453,9 @@ class JSONProtocol:
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = packet.arguments
payload["arguments"] = [
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
]
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
@@ -255,7 +467,7 @@ class JSONProtocol:
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = packet.result
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):

View File

@@ -4,3 +4,55 @@ from __future__ import annotations
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000
def camel_to_snake(name: str) -> str:
"""Convert a camelCase string to snake_case."""
result = []
last_chr = ""
for char in name:
if char.isupper():
if not last_chr.isupper() and result:
result.append("_")
result.append(char.lower())
else:
result.append(char)
last_chr = char
return "".join(result)
def snake_to_camel(name: str, lower_case: bool = True) -> str:
"""Convert a snake_case string to camelCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
result = []
for part in parts:
if part.lower() in abbreviations:
result.append(part.upper())
else:
if result or not lower_case:
result.append(part.capitalize())
else:
result.append(part.lower())
return "".join(result)

View File

@@ -1,11 +1,4 @@
from typing import Any
class APIMod:
def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ...
@property
def acronym(self) -> str: ...
@property
def settings(self) -> dict[str, Any]: ...
def encode(obj: Any) -> bytes: ...
def decode(data: bytes) -> Any: ...

View File

@@ -1,8 +1,6 @@
use crate::APIMod;
use chrono::{TimeZone, Utc};
use pyo3::types::PyDict;
use pyo3::{prelude::*, IntoPyObjectExt};
use std::collections::HashMap;
use std::io::Read;
pub fn read_object(
@@ -206,13 +204,12 @@ fn read_array(
let obj1 = read_object(py, cursor, false)?;
if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) {
let obj2 = read_object(py, cursor, true)?;
return Ok(APIMod {
acronym: obj1.extract::<String>(py)?,
settings: obj2.extract::<HashMap<String, PyObject>>(py)?,
}
.into_pyobject(py)?
.into_any()
.unbind());
let api_mod_dict = PyDict::new(py);
api_mod_dict.set_item("acronym", obj1)?;
api_mod_dict.set_item("settings", obj2)?;
return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind());
} else {
items.push(obj1);
i += 1;

View File

@@ -1,8 +1,7 @@
use crate::APIMod;
use chrono::{DateTime, Utc};
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods};
use chrono::{DateTime, Utc};
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods};
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
use pyo3::{Bound, PyAny, PyRef, Python};
use pyo3::{Bound, PyAny};
use std::io::Write;
fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) {
@@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec<u8>, obj: &Bound<'_, PyDict>) {
}
}
fn write_nil(buf: &mut Vec<u8>){
fn write_nil(buf: &mut Vec<u8>) {
rmp::encode::write_nil(buf).unwrap();
}
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
fn write_api_mod(buf: &mut Vec<u8>, api_mod: PyRef<APIMod>) {
rmp::encode::write_array_len(buf, 2).unwrap();
rmp::encode::write_str(buf, &api_mod.acronym).unwrap();
rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap();
for (k, v) in api_mod.settings.iter() {
rmp::encode::write_str(buf, k).unwrap();
Python::with_gil(|py| write_object(buf, &v.bind(py)));
fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool {
if let Ok(Some(acronym)) = dict.get_item("acronym") {
if let Ok(acronym_str) = acronym.extract::<String>() {
return acronym_str.len() == 2;
}
}
false
}
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
fn write_api_mod(buf: &mut Vec<u8>, api_mod: &Bound<'_, PyDict>) -> PyResult<()> {
let acronym = api_mod
.get_item("acronym")?
.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?;
let acronym_str = acronym.extract::<String>()?;
let settings = api_mod
.get_item("settings")?
.unwrap_or_else(|| PyDict::new(acronym.py()).into_any());
let settings_dict = settings.downcast::<PyDict>()?;
rmp::encode::write_array_len(buf, 2).unwrap();
rmp::encode::write_str(buf, &acronym_str).unwrap();
rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap();
for (k, v) in settings_dict.iter() {
let key_str = k.extract::<String>()?;
rmp::encode::write_str(buf, &key_str).unwrap();
write_object(buf, &v);
}
Ok(())
}
fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) {
@@ -111,21 +133,23 @@ pub fn write_object(buf: &mut Vec<u8>, obj: &Bound<'_, PyAny>) {
} else if let Ok(string) = obj.downcast::<PyString>() {
write_string(buf, string);
} else if let Ok(boolean) = obj.downcast::<PyBool>() {
write_bool(buf, boolean);
write_bool(buf, boolean);
} else if let Ok(float) = obj.downcast::<PyFloat>() {
write_float(buf, float);
write_float(buf, float);
} else if let Ok(integer) = obj.downcast::<PyInt>() {
write_integer(buf, integer);
write_integer(buf, integer);
} else if let Ok(bytes) = obj.downcast::<PyBytes>() {
write_bin(buf, bytes);
} else if let Ok(dict) = obj.downcast::<PyDict>() {
write_hashmap(buf, dict);
if is_api_mod(dict) {
write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict));
} else {
write_hashmap(buf, dict);
}
} else if let Ok(_none) = obj.downcast::<PyNone>() {
write_nil(buf);
} else if let Ok(datetime) = obj.downcast::<PyDateTime>() {
write_datetime(buf, datetime);
} else if let Ok(api_mod) = obj.extract::<PyRef<APIMod>>() {
write_api_mod(buf, api_mod);
} else {
panic!("Unsupported type");
}

View File

@@ -2,30 +2,6 @@ mod decode;
mod encode;
use pyo3::prelude::*;
use std::collections::HashMap;
#[pyclass]
struct APIMod {
#[pyo3(get, set)]
acronym: String,
#[pyo3(get, set)]
settings: HashMap<String, PyObject>,
}
#[pymethods]
impl APIMod {
#[new]
fn new(acronym: String, settings: HashMap<String, PyObject>) -> Self {
APIMod { acronym, settings }
}
fn __repr__(&self) -> String {
format!(
"APIMod(acronym='{}', settings={:?})",
self.acronym, self.settings
)
}
}
#[pyfunction]
#[pyo3(name = "encode")]
@@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult<PyObject> {
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_py, m)?)?;
m.add_function(wrap_pyfunction!(decode_py, m)?)?;
m.add_class::<APIMod>()?;
Ok(())
}