From a11ea743a71ffe14388bf4c5cda7132fa5fc02c9 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 11:00:57 +0000 Subject: [PATCH 1/4] fix(signarl): wrong msgpack encode --- app/models/signalr.py | 46 +++++++++++++++++-- app/signalr/packet.py | 2 +- .../msgpack_lazer_api/msgpack_lazer_api.pyi | 2 +- packages/msgpack_lazer_api/src/decode.rs | 4 +- packages/msgpack_lazer_api/src/encode.rs | 10 ++-- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 37b2741..202da4f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,10 +1,12 @@ from __future__ import annotations import datetime -from typing import Any, get_origin +from enum import Enum +from typing import Any from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Field, TypeAdapter, @@ -17,22 +19,56 @@ def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): v = getattr(value, field) - anno = get_origin(info.annotation) - if anno and issubclass(anno, BaseModel): + typ = v.__class__ + if issubclass(typ, BaseModel): data.append(serialize_to_list(v)) - elif anno and issubclass(anno, list): + elif issubclass(typ, list): data.append( TypeAdapter( info.annotation, config=ConfigDict(arbitrary_types_allowed=True) ).dump_python(v) ) - elif isinstance(v, datetime.datetime): + elif issubclass(typ, datetime.datetime): data.append([v, 0]) + elif issubclass(typ, Enum): + list_ = list(typ) + data.append(list_.index(v) if v in list_ else v.value) else: data.append(v) return data +def _by_index(v: Any, class_: type[Enum]): + enum_list = list(class_) + if not isinstance(v, int): + return v + if 0 <= v < len(enum_list): + return enum_list[v] + raise ValueError( + f"Value {v} is out of range for enum " + f"{class_.__name__} with {len(enum_list)} items" + ) + + +def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: + return BeforeValidator(lambda v: _by_index(v, enum_class)) + + +def msgpack_union(v): + data = v[1] + data.append(v[0]) + return data + + +def msgpack_union_dump(v: BaseModel) -> list[Any]: + _type = getattr(v, "type", None) + if _type is None: + raise ValueError( + f"Model {v.__class__.__name__} does not have a '_type' attribute" + ) + return [_type, serialize_to_list(v)] + + class MessagePackArrayModel(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index e361ef8..387231c 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -158,7 +158,7 @@ class MsgpackProtocol: result_kind = 2 if packet.error: result_kind = 1 - elif packet.result is None: + elif packet.result is not None: result_kind = 3 payload.extend( [ diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index 88b79c5..b8653f0 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -5,7 +5,7 @@ class APIMod: @property def acronym(self) -> str: ... @property - def settings(self) -> str: ... + def settings(self) -> dict[str, Any]: ... def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index 15156ca..b8e239b 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -13,6 +13,8 @@ pub fn read_object( match rmp::decode::read_marker(cursor) { Ok(marker) => match marker { rmp::Marker::Null => Ok(py.None()), + rmp::Marker::True => Ok(true.into_py_any(py)?), + rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::U8 => { @@ -86,8 +88,6 @@ pub fn read_object( cursor.read_exact(&mut data).map_err(to_py_err)?; Ok(data.into_pyobject(py)?.into_any().unbind()) } - rmp::Marker::True => Ok(true.into_py_any(py)?), - rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32), rmp::Marker::Str8 => { let mut buf = [0u8; 1]; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 88a732b..0e0907c 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -110,12 +110,12 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { write_list(buf, list); } else if let Ok(string) = obj.downcast::() { write_string(buf, string); - } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); - } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); } else if let Ok(boolean) = obj.downcast::() { - write_bool(buf, boolean); + write_bool(buf, boolean); + } else if let Ok(float) = obj.downcast::() { + write_float(buf, float); + } else if let Ok(integer) = obj.downcast::() { + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() { From 5ccb35dc8be2ca234d5c04b1f3a11f8e95fea094 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 14:59:12 +0000 Subject: [PATCH 2/4] fix(signalr): encode enum by index --- app/models/signalr.py | 34 +++++++++++++++++----------------- app/signalr/hub/hub.py | 8 +++++++- app/signalr/packet.py | 4 +++- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 202da4f..9e189e9 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -15,26 +15,26 @@ from pydantic import ( ) +def serialize_msgpack(v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return serialize_to_list(v) + elif issubclass(typ, list): + return TypeAdapter( + typ, config=ConfigDict(arbitrary_types_allowed=True) + ).dump_python(v) + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): - v = getattr(value, field) - typ = v.__class__ - if issubclass(typ, BaseModel): - data.append(serialize_to_list(v)) - elif issubclass(typ, list): - data.append( - TypeAdapter( - info.annotation, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - ) - elif issubclass(typ, datetime.datetime): - data.append([v, 0]) - elif issubclass(typ, Enum): - list_ = list(typ) - data.append(list_.index(v) if v in list_ else v.value) - else: - data.append(v) + data.append(serialize_msgpack(v=getattr(value, field))) return data diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 276140f..a11fbe7 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -2,12 +2,14 @@ from __future__ import annotations from abc import abstractmethod import asyncio +from enum import Enum +import inspect import time from typing import Any from app.config import settings from app.log import logger -from app.models.signalr import UserState +from app.models.signalr import UserState, _by_index from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, @@ -265,6 +267,10 @@ class Hub[TState: UserState]: continue if issubclass(param.annotation, BaseModel): call_params.append(param.annotation.model_validate(args.pop(0))) + elif inspect.isclass(param.annotation) and issubclass( + param.annotation, Enum + ): + call_params.append(_by_index(args.pop(0), param.annotation)) else: call_params.append(args.pop(0)) return await method_(client, *call_params) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 387231c..de5ce8a 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -8,6 +8,8 @@ from typing import ( Protocol as TypingProtocol, ) +from app.models.signalr import serialize_msgpack + import msgpack_lazer_api as m SEP = b"\x1e" @@ -151,7 +153,7 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append(packet.arguments) + payload.append([serialize_msgpack(arg) for arg in packet.arguments]) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): From 0f1a57afba5b73339a817f075ae9e3141bc0d48b Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 2 Aug 2025 15:02:12 +0000 Subject: [PATCH 3/4] fix(user): last_visit is nullable --- app/database/lazer_user.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 3bd751b..2717c3a 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -66,7 +66,7 @@ class UserBase(UTCBaseModel, SQLModel): is_active: bool = True is_bot: bool = False is_supporter: bool = False - last_visit: datetime = Field( + last_visit: datetime | None = Field( default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) ) pm_friends_only: bool = False From 9f7ab812134910abd5a905633547d4792833fb41 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 3 Aug 2025 09:45:04 +0000 Subject: [PATCH 4/4] feat(signalr): support json & msgpack protocol for all hubs --- app/models/metadata_hub.py | 124 ++++------ app/models/score.py | 44 +--- app/models/signalr.py | 67 +----- app/models/spectator_hub.py | 50 ++-- app/signalr/hub/hub.py | 22 +- app/signalr/hub/metadata.py | 22 +- app/signalr/hub/spectator.py | 25 +- app/signalr/packet.py | 224 +++++++++++++++++- app/utils.py | 52 ++++ .../msgpack_lazer_api/msgpack_lazer_api.pyi | 7 - packages/msgpack_lazer_api/src/decode.rs | 15 +- packages/msgpack_lazer_api/src/encode.rs | 62 +++-- packages/msgpack_lazer_api/src/lib.rs | 25 -- 13 files changed, 432 insertions(+), 307 deletions(-) diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 8ae3e65..3206d03 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -1,114 +1,85 @@ from __future__ import annotations from enum import IntEnum -from typing import Any, Literal +from typing import ClassVar, Literal -from app.models.signalr import UserState +from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field -class _UserActivity(BaseModel): - model_config = ConfigDict(serialize_by_alias=True) - type: Literal[ - "ChoosingBeatmap", - "InSoloGame", - "WatchingReplay", - "SpectatingUser", - "SearchingForLobby", - "InLobby", - "InMultiplayerGame", - "SpectatingMultiplayerGame", - "InPlaylistGame", - "EditingBeatmap", - "ModdingBeatmap", - "TestingBeatmap", - "InDailyChallengeLobby", - "PlayingDailyChallenge", - ] = Field(alias="$dtype") - value: Any | None = Field(alias="$value") +class _UserActivity(SignalRUnionMessage): ... class ChoosingBeatmap(_UserActivity): - type: Literal["ChoosingBeatmap"] = Field(alias="$dtype") - - -class InGameValue(BaseModel): - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") - ruleset_id: int = Field(alias="RulesetID") - ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb") + union_type: ClassVar[Literal[11]] = 11 class _InGame(_UserActivity): - value: InGameValue = Field(alias="$value") + beatmap_id: int + beatmap_display_title: str + ruleset_id: int + ruleset_playing_verb: str class InSoloGame(_InGame): - type: Literal["InSoloGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[12]] = 12 class InMultiplayerGame(_InGame): - type: Literal["InMultiplayerGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[23]] = 23 class SpectatingMultiplayerGame(_InGame): - type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[24]] = 24 class InPlaylistGame(_InGame): - type: Literal["InPlaylistGame"] = Field(alias="$dtype") + union_type: ClassVar[Literal[31]] = 31 -class EditingBeatmapValue(BaseModel): - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") +class PlayingDailyChallenge(_InGame): + union_type: ClassVar[Literal[52]] = 52 class EditingBeatmap(_UserActivity): - type: Literal["EditingBeatmap"] = Field(alias="$dtype") - value: EditingBeatmapValue = Field(alias="$value") + union_type: ClassVar[Literal[41]] = 41 + beatmap_id: int + beatmap_display_title: str -class TestingBeatmap(_UserActivity): - type: Literal["TestingBeatmap"] = Field(alias="$dtype") +class TestingBeatmap(EditingBeatmap): + union_type: ClassVar[Literal[43]] = 43 -class ModdingBeatmap(_UserActivity): - type: Literal["ModdingBeatmap"] = Field(alias="$dtype") - - -class WatchingReplayValue(BaseModel): - score_id: int = Field(alias="ScoreID") - player_name: str = Field(alias="PlayerName") - beatmap_id: int = Field(alias="BeatmapID") - beatmap_display_title: str = Field(alias="BeatmapDisplayTitle") +class ModdingBeatmap(EditingBeatmap): + union_type: ClassVar[Literal[42]] = 42 class WatchingReplay(_UserActivity): - type: Literal["WatchingReplay"] = Field(alias="$dtype") - value: int | None = Field(alias="$value") # Replay ID + union_type: ClassVar[Literal[13]] = 13 + score_id: int + player_name: str + beatmap_id: int + beatmap_display_title: str class SpectatingUser(WatchingReplay): - type: Literal["SpectatingUser"] = Field(alias="$dtype") + union_type: ClassVar[Literal[14]] = 14 class SearchingForLobby(_UserActivity): - type: Literal["SearchingForLobby"] = Field(alias="$dtype") - - -class InLobbyValue(BaseModel): - room_id: int = Field(alias="RoomID") - room_name: str = Field(alias="RoomName") + union_type: ClassVar[Literal[21]] = 21 class InLobby(_UserActivity): - type: Literal["InLobby"] = "InLobby" + union_type: ClassVar[Literal[22]] = 22 + room_id: int + room_name: str class InDailyChallengeLobby(_UserActivity): - type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype") + union_type: ClassVar[Literal[51]] = 51 UserActivity = ( @@ -128,23 +99,28 @@ UserActivity = ( ) -class MetadataClientState(UserState): - user_activity: UserActivity | None = None - status: OnlineStatus | None = None - - def to_dict(self) -> dict[str, Any] | None: - if self.status is None or self.status == OnlineStatus.OFFLINE: - return None - dumped = self.model_dump(by_alias=True, exclude_none=True) - return { - "Activity": dumped.get("user_activity"), - "Status": dumped.get("status"), - } +class UserPresence(BaseModel): + activity: UserActivity | None = Field( + default=None, metadata=SignalRMeta(use_upper_case=True) + ) + status: OnlineStatus | None = Field( + default=None, metadata=SignalRMeta(use_upper_case=True) + ) @property def pushable(self) -> bool: return self.status is not None and self.status != OnlineStatus.OFFLINE + @property + def for_push(self) -> "UserPresence | None": + return UserPresence( + activity=self.activity, + status=self.status, + ) + + +class MetadataClientState(UserPresence, UserState): ... + class OnlineStatus(IntEnum): OFFLINE = 0 # 隐身 diff --git a/app/models/score.py b/app/models/score.py index bfc9f53..cef6b28 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -1,6 +1,6 @@ from __future__ import annotations -from enum import Enum, IntEnum +from enum import Enum from typing import Literal, TypedDict from .mods import API_MODS, APIMod, init_mods @@ -93,43 +93,6 @@ class HitResult(str, Enum): ) -class HitResultInt(IntEnum): - PERFECT = 0 - GREAT = 1 - GOOD = 2 - OK = 3 - MEH = 4 - MISS = 5 - - LARGE_TICK_HIT = 6 - SMALL_TICK_HIT = 7 - SLIDER_TAIL_HIT = 8 - - LARGE_BONUS = 9 - SMALL_BONUS = 10 - - LARGE_TICK_MISS = 11 - SMALL_TICK_MISS = 12 - - IGNORE_HIT = 13 - IGNORE_MISS = 14 - - NONE = 15 - COMBO_BREAK = 16 - - LEGACY_COMBO_INCREASE = 99 - - def is_hit(self) -> bool: - return self not in ( - HitResultInt.NONE, - HitResultInt.IGNORE_MISS, - HitResultInt.COMBO_BREAK, - HitResultInt.LARGE_TICK_MISS, - HitResultInt.SMALL_TICK_MISS, - HitResultInt.MISS, - ) - - class LeaderboardType(Enum): GLOBAL = "global" FRIENDS = "friend" @@ -138,7 +101,6 @@ class LeaderboardType(Enum): ScoreStatistics = dict[HitResult, int] -ScoreStatisticsInt = dict[HitResultInt, int] class SoloScoreSubmissionInfo(BaseModel): @@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel): class LegacyReplaySoloScoreInfo(TypedDict): online_id: int mods: list[APIMod] - statistics: ScoreStatisticsInt - maximum_statistics: ScoreStatisticsInt + statistics: ScoreStatistics + maximum_statistics: ScoreStatistics client_version: str rank: Rank user_id: int diff --git a/app/models/signalr.py b/app/models/signalr.py index 9e189e9..90ef95f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,41 +1,21 @@ from __future__ import annotations -import datetime +from dataclasses import dataclass from enum import Enum -from typing import Any +from typing import Any, ClassVar from pydantic import ( BaseModel, BeforeValidator, - ConfigDict, Field, - TypeAdapter, - model_serializer, - model_validator, ) -def serialize_msgpack(v: Any) -> Any: - typ = v.__class__ - if issubclass(typ, BaseModel): - return serialize_to_list(v) - elif issubclass(typ, list): - return TypeAdapter( - typ, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - elif issubclass(typ, datetime.datetime): - return [v, 0] - elif issubclass(typ, Enum): - list_ = list(typ) - return list_.index(v) if v in list_ else v.value - return v - - -def serialize_to_list(value: BaseModel) -> list[Any]: - data = [] - for field, info in value.__class__.model_fields.items(): - data.append(serialize_msgpack(v=getattr(value, field))) - return data +@dataclass +class SignalRMeta: + member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute + json_ignore: bool = False # implement of JsonIgnore (json) attribute + use_upper_case: bool = False # use upper CamelCase for field names def _by_index(v: Any, class_: type[Enum]): @@ -54,37 +34,8 @@ def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: return BeforeValidator(lambda v: _by_index(v, enum_class)) -def msgpack_union(v): - data = v[1] - data.append(v[0]) - return data - - -def msgpack_union_dump(v: BaseModel) -> list[Any]: - _type = getattr(v, "type", None) - if _type is None: - raise ValueError( - f"Model {v.__class__.__name__} does not have a '_type' attribute" - ) - return [_type, serialize_to_list(v)] - - -class MessagePackArrayModel(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode="before") - @classmethod - def unpack(cls, v: Any) -> Any: - if isinstance(v, list): - fields = list(cls.model_fields.keys()) - if len(v) != len(fields): - raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}") - return dict(zip(fields, v)) - return v - - @model_serializer - def serialize(self) -> list[Any]: - return serialize_to_list(self) +class SignalRUnionMessage(BaseModel): + union_type: ClassVar[int] class Transport(BaseModel): diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index 994e083..a9e9042 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -5,14 +5,14 @@ from enum import IntEnum from typing import Any from app.models.beatmap import BeatmapRankStatus +from app.models.mods import APIMod from .score import ( - ScoreStatisticsInt, + ScoreStatistics, ) -from .signalr import MessagePackArrayModel, UserState +from .signalr import SignalRMeta, UserState -from msgpack_lazer_api import APIMod -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field, field_validator class SpectatedUserState(IntEnum): @@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum): Quit = 5 -class SpectatorState(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class SpectatorState(BaseModel): beatmap_id: int | None = None ruleset_id: int | None = None # 0,1,2,3 mods: list[APIMod] = Field(default_factory=list) state: SpectatedUserState - maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict) + maximum_statistics: ScoreStatistics = Field(default_factory=dict) def __eq__(self, other: object) -> bool: if not isinstance(other, SpectatorState): @@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel): ) -class ScoreProcessorStatistics(MessagePackArrayModel): - base_score: int - maximum_base_score: int +class ScoreProcessorStatistics(BaseModel): + base_score: float + maximum_base_score: float accuracy_judgement_count: int combo_portion: float - bouns_portion: float + bonus_portion: float -class FrameHeader(MessagePackArrayModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +class FrameHeader(BaseModel): total_score: int - acc: float + accuracy: float combo: int max_combo: int - statistics: ScoreStatisticsInt = Field(default_factory=dict) + statistics: ScoreStatistics = Field(default_factory=dict) score_processor_statistics: ScoreProcessorStatistics received_time: datetime.datetime mods: list[APIMod] = Field(default_factory=list) @@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel): # SMOKE = 16 -class LegacyReplayFrame(MessagePackArrayModel): +class LegacyReplayFrame(BaseModel): time: float # from ReplayFrame,the parent of LegacyReplayFrame - x: float | None = None - y: float | None = None + mouse_x: float | None = None + mouse_y: float | None = None button_state: int + header: FrameHeader | None = Field( + default=None, metadata=[SignalRMeta(member_ignore=True)] + ) -class FrameDataBundle(MessagePackArrayModel): + +class FrameDataBundle(BaseModel): header: FrameHeader frames: list[LegacyReplayFrame] @@ -106,18 +106,16 @@ class APIUser(BaseModel): class ScoreInfo(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - mods: list[APIMod] user: APIUser ruleset: int - maximum_statistics: ScoreStatisticsInt + maximum_statistics: ScoreStatistics id: int | None = None total_score: int | None = None - acc: float | None = None + accuracy: float | None = None max_combo: int | None = None combo: int | None = None - statistics: ScoreStatisticsInt = Field(default_factory=dict) + statistics: ScoreStatistics = Field(default_factory=dict) class StoreScore(BaseModel): diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index a11fbe7..f3c5b29 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -2,14 +2,12 @@ from __future__ import annotations from abc import abstractmethod import asyncio -from enum import Enum -import inspect import time from typing import Any from app.config import settings from app.log import logger -from app.models.signalr import UserState, _by_index +from app.models.signalr import UserState from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, @@ -23,7 +21,6 @@ from app.signalr.store import ResultStore from app.signalr.utils import get_signature from fastapi import WebSocket -from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect @@ -51,7 +48,7 @@ class Client: self.connection_id = connection_id self.connection_token = connection_token self.connection = connection - self.procotol = protocol + self.protocol = protocol self._listen_task: asyncio.Task | None = None self._ping_task: asyncio.Task | None = None self._store = ResultStore() @@ -64,14 +61,14 @@ class Client: return int(self.connection_id) async def send_packet(self, packet: Packet): - await self.connection.send_bytes(self.procotol.encode(packet)) + await self.connection.send_bytes(self.protocol.encode(packet)) async def receive_packets(self) -> list[Packet]: message = await self.connection.receive() d = message.get("bytes") or message.get("text", "").encode() if not d: return [] - return self.procotol.decode(d) + return self.protocol.decode(d) async def _ping(self): while True: @@ -265,14 +262,9 @@ class Hub[TState: UserState]: for name, param in signature.parameters.items(): if name == "self" or param.annotation is Client: continue - if issubclass(param.annotation, BaseModel): - call_params.append(param.annotation.model_validate(args.pop(0))) - elif inspect.isclass(param.annotation) and issubclass( - param.annotation, Enum - ): - call_params.append(_by_index(args.pop(0), param.annotation)) - else: - call_params.append(args.pop(0)) + call_params.append( + client.protocol.validate_object(args.pop(0), param.annotation) + ) return await method_(client, *call_params) async def call(self, client: Client, method: str, *args: Any) -> Any: diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 227cf7b..64232c0 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -12,7 +12,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv from .hub import Client, Hub -from pydantic import TypeAdapter from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -32,7 +31,7 @@ class MetadataHub(Hub[MetadataClientState]): ) -> set[Coroutine]: if store is not None and not store.pushable: return set() - data = store.to_dict() if store else None + data = store.for_push if store else None return { self.broadcast_group_call( self.online_presence_watchers_group(), @@ -103,7 +102,7 @@ class MetadataHub(Hub[MetadataClientState]): self.friend_presence_watchers_group(friend_id), "FriendPresenceUpdated", friend_id, - friend_state.to_dict(), + friend_state if friend_state.pushable else None, ) ) await asyncio.gather(*tasks) @@ -123,27 +122,24 @@ class MetadataHub(Hub[MetadataClientState]): client, "UserPresenceUpdated", user_id, - store.to_dict(), + store.for_push, ) ) await asyncio.gather(*tasks) - async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None: + async def UpdateActivity( + self, client: Client, activity: UserActivity | None + ) -> None: user_id = int(client.connection_id) - activity = ( - TypeAdapter(UserActivity).validate_python(activity_dict) - if activity_dict - else None - ) store = self.get_or_create_state(client) - store.user_activity = activity + store.activity = activity tasks = self.broadcast_tasks(user_id, store) tasks.add( self.call_noblock( client, "UserPresenceUpdated", user_id, - store.to_dict(), + store.for_push, ) ) await asyncio.gather(*tasks) @@ -155,7 +151,7 @@ class MetadataHub(Hub[MetadataClientState]): client, "UserPresenceUpdated", user_id, - store.to_dict(), + store, ) for user_id, store in self.state.items() if store.pushable diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index bd311ec..b9a3c99 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken from app.dependencies.database import engine from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int -from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt -from app.models.signalr import serialize_to_list +from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics from app.models.spectator_hub import ( APIUser, FrameDataBundle, @@ -69,8 +68,8 @@ def save_replay( md5: str, username: str, score: Score, - statistics: ScoreStatisticsInt, - maximum_statistics: ScoreStatisticsInt, + statistics: ScoreStatistics, + maximum_statistics: ScoreStatistics, frames: list[LegacyReplayFrame], ) -> None: data = bytearray() @@ -107,8 +106,8 @@ def save_replay( last_time = 0 for frame in frames: frame_strs.append( - f"{frame.time - last_time}|{frame.x or 0.0}" - f"|{frame.y or 0.0}|{frame.button_state}" + f"{frame.time - last_time}|{frame.mouse_x or 0.0}" + f"|{frame.mouse_y or 0.0}|{frame.button_state}" ) last_time = frame.time frame_strs.append("-12345|0|0|0") @@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]): async def on_client_connect(self, client: Client) -> None: tasks = [ - self.call_noblock( - client, "UserBeganPlaying", user_id, serialize_to_list(store.state) - ) + self.call_noblock(client, "UserBeganPlaying", user_id, store.state) for user_id, store in self.state.items() if store.state is not None ] @@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserBeganPlaying", user_id, - serialize_to_list(state), + state, ) async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: @@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]): state = self.get_or_create_state(client) if not state.score: return - state.score.score_info.acc = frame_data.header.acc + state.score.score_info.accuracy = frame_data.header.accuracy state.score.score_info.combo = frame_data.header.combo state.score.score_info.max_combo = frame_data.header.max_combo state.score.score_info.statistics = frame_data.header.statistics @@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserSentFrames", user_id, - frame_data.model_dump(), + frame_data, ) async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: @@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]): self.group_id(user_id), "UserFinishedPlaying", user_id, - serialize_to_list(state) if state else None, + state, ) async def StartWatchingUser(self, client: Client, target_id: int) -> None: @@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]): client, "UserBeganPlaying", target_id, - serialize_to_list(target_store.state), + target_store.state, ) store = self.get_or_create_state(client) store.watched_user.add(target_id) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index de5ce8a..be98c39 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -1,16 +1,24 @@ from __future__ import annotations from dataclasses import dataclass -from enum import IntEnum +import datetime +from enum import Enum, IntEnum +import inspect import json +from types import NoneType, UnionType from typing import ( Any, Protocol as TypingProtocol, + Union, + get_args, + get_origin, ) -from app.models.signalr import serialize_msgpack +from app.models.signalr import SignalRMeta, SignalRUnionMessage +from app.utils import camel_to_snake, snake_to_camel import msgpack_lazer_api as m +from pydantic import BaseModel SEP = b"\x1e" @@ -75,8 +83,61 @@ class Protocol(TypingProtocol): @staticmethod def encode(packet: Packet) -> bytes: ... + @classmethod + def validate_object(cls, v: Any, typ: type) -> Any: ... + class MsgpackProtocol: + @classmethod + def serialize_msgpack(cls, v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return cls.serialize_to_list(v) + elif issubclass(typ, list): + return [cls.serialize_msgpack(item) for item in v] + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif isinstance(v, dict): + return { + cls.serialize_msgpack(k): cls.serialize_msgpack(value) + for k, value in v.items() + } + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + @classmethod + def serialize_to_list(cls, value: BaseModel) -> list[Any]: + values = [] + for field, info in value.__class__.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.member_ignore: + continue + values.append(cls.serialize_msgpack(v=getattr(value, field))) + if issubclass(value.__class__, SignalRUnionMessage): + return [value.__class__.union_type, values] + else: + return values + + @staticmethod + def process_object(v: Any, typ: type[BaseModel]) -> Any: + if isinstance(v, list): + d = {} + for i, f in enumerate(typ.model_fields.items()): + field, info = f + if info.exclude: + continue + anno = info.annotation + if anno is None: + d[camel_to_snake(field)] = v[i] + continue + d[field] = MsgpackProtocol.validate_object(v[i], anno) + return d + return v + @staticmethod def _encode_varint(value: int) -> bytes: result = [] @@ -142,6 +203,49 @@ class MsgpackProtocol: ] raise ValueError(f"Unsupported packet type: {packet_type}") + @classmethod + def validate_object(cls, v: Any, typ: type) -> Any: + if issubclass(typ, BaseModel): + return typ.model_validate(obj=cls.process_object(v, typ)) + elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): + return v[0] + elif isinstance(v, list): + return [cls.validate_object(item, get_args(typ)[0]) for item in v] + elif inspect.isclass(typ) and issubclass(typ, Enum): + list_ = list(typ) + return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) + elif get_origin(typ) is dict: + return { + cls.validate_object(k, get_args(typ)[0]): cls.validate_object( + v, get_args(typ)[1] + ) + for k, v in v.items() + } + elif (origin := get_origin(typ)) is Union or origin is UnionType: + args = get_args(typ) + if len(args) == 2 and NoneType in args: + non_none_args = [arg for arg in args if arg is not NoneType] + if len(non_none_args) == 1: + if v is None: + return None + return cls.validate_object(v, non_none_args[0]) + + # suppose use `MessagePack-CSharp Union | None` + # except `X (Other Type) | None` + if NoneType in args and v is None: + return None + if not all(issubclass(arg, SignalRUnionMessage) for arg in args): + raise ValueError( + f"Cannot validate {v} to {typ}, " + "only SignalRUnionMessage subclasses are supported" + ) + union_type = v[0] + for arg in args: + assert issubclass(arg, SignalRUnionMessage) + if arg.union_type == union_type: + return cls.validate_object(v[1], arg) + return v + @staticmethod def encode(packet: Packet) -> bytes: payload = [packet.type.value, packet.header or {}] @@ -153,7 +257,9 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append([serialize_msgpack(arg) for arg in packet.arguments]) + payload.append( + [MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments] + ) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): @@ -166,7 +272,9 @@ class MsgpackProtocol: [ packet.invocation_id, result_kind, - packet.error or packet.result or None, + packet.error + or MsgpackProtocol.serialize_msgpack(packet.result) + or None, ] ) elif isinstance(packet, ClosePacket): @@ -183,6 +291,62 @@ class MsgpackProtocol: class JSONProtocol: + @classmethod + def serialize_to_json(cls, v: Any): + typ = v.__class__ + if issubclass(typ, BaseModel): + return cls.serialize_model(v) + elif isinstance(v, dict): + return { + cls.serialize_to_json(k): cls.serialize_to_json(value) + for k, value in v.items() + } + elif isinstance(v, list): + return [cls.serialize_to_json(item) for item in v] + elif isinstance(v, datetime.datetime): + return v.isoformat() + elif isinstance(v, Enum): + return v.value + return v + + @classmethod + def serialize_model(cls, v: BaseModel) -> dict[str, Any]: + d = {} + for field, info in v.__class__.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.json_ignore: + continue + d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = ( + cls.serialize_to_json(getattr(v, field)) + ) + if issubclass(v.__class__, SignalRUnionMessage): + return { + "$dtype": v.__class__.__name__, + "$value": d, + } + return d + + @staticmethod + def process_object( + v: Any, typ: type[BaseModel], from_union: bool = False + ) -> dict[str, Any]: + d = {} + for field, info in typ.model_fields.items(): + metadata = next( + (m for m in info.metadata if isinstance(m, SignalRMeta)), None + ) + if metadata and metadata.json_ignore: + continue + value = v.get(snake_to_camel(field, not from_union)) + anno = typ.model_fields[field].annotation + if anno is None: + d[field] = value + continue + d[field] = JSONProtocol.validate_object(value, anno) + return d + @staticmethod def decode(input: bytes) -> list[Packet]: packets_raw = input.removesuffix(SEP).split(SEP) @@ -227,6 +391,52 @@ class JSONProtocol: ] raise ValueError(f"Unsupported packet type: {packet_type}") + @classmethod + def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any: + if issubclass(typ, BaseModel): + return typ.model_validate(JSONProtocol.process_object(v, typ, from_union)) + elif inspect.isclass(typ) and issubclass(typ, datetime.datetime): + return datetime.datetime.fromisoformat(v) + elif isinstance(v, list): + return [cls.validate_object(item, get_args(typ)[0]) for item in v] + elif inspect.isclass(typ) and issubclass(typ, Enum): + list_ = list(typ) + return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) + elif get_origin(typ) is dict: + return { + cls.validate_object(k, get_args(typ)[0]): cls.validate_object( + v, get_args(typ)[1] + ) + for k, v in v.items() + } + elif (origin := get_origin(typ)) is Union or origin is UnionType: + args = get_args(typ) + if len(args) == 2 and NoneType in args: + non_none_args = [arg for arg in args if arg is not NoneType] + if len(non_none_args) == 1: + if v is None: + return None + return cls.validate_object(v, non_none_args[0]) + + # suppose use `MessagePack-CSharp Union | None` + # except `X (Other Type) | None` + if NoneType in args and v is None: + return None + if not all( + issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args + ): + raise ValueError( + f"Cannot validate {v} to {typ}, " + "only SignalRUnionMessage subclasses are supported" + ) + # https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs + union_type = v["$dtype"] + for arg in args: + assert issubclass(arg, SignalRUnionMessage) + if arg.__name__ == union_type: + return cls.validate_object(v["$value"], arg, True) + return v + @staticmethod def encode(packet: Packet) -> bytes: payload: dict[str, Any] = { @@ -243,7 +453,9 @@ class JSONProtocol: if packet.invocation_id is not None: payload["invocationId"] = packet.invocation_id if packet.arguments is not None: - payload["arguments"] = packet.arguments + payload["arguments"] = [ + JSONProtocol.serialize_to_json(arg) for arg in packet.arguments + ] if packet.stream_ids is not None: payload["streamIds"] = packet.stream_ids elif isinstance(packet, CompletionPacket): @@ -255,7 +467,7 @@ class JSONProtocol: if packet.error is not None: payload["error"] = packet.error if packet.result is not None: - payload["result"] = packet.result + payload["result"] = JSONProtocol.serialize_to_json(packet.result) elif isinstance(packet, PingPacket): pass elif isinstance(packet, ClosePacket): diff --git a/app/utils.py b/app/utils.py index 09e8fdc..0d759a1 100644 --- a/app/utils.py +++ b/app/utils.py @@ -4,3 +4,55 @@ from __future__ import annotations def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" return (timestamp + 62135596800) * 10_000_000 + + +def camel_to_snake(name: str) -> str: + """Convert a camelCase string to snake_case.""" + result = [] + last_chr = "" + for char in name: + if char.isupper(): + if not last_chr.isupper() and result: + result.append("_") + result.append(char.lower()) + else: + result.append(char) + last_chr = char + return "".join(result) + + +def snake_to_camel(name: str, lower_case: bool = True) -> str: + """Convert a snake_case string to camelCase.""" + if not name: + return name + + parts = name.split("_") + if not parts: + return name + + # 常见缩写词列表 + abbreviations = { + "id", + "url", + "api", + "http", + "https", + "xml", + "json", + "css", + "html", + "sql", + "db", + } + + result = [] + for part in parts: + if part.lower() in abbreviations: + result.append(part.upper()) + else: + if result or not lower_case: + result.append(part.capitalize()) + else: + result.append(part.lower()) + + return "".join(result) diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index b8653f0..433c53b 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -1,11 +1,4 @@ from typing import Any -class APIMod: - def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ... - @property - def acronym(self) -> str: ... - @property - def settings(self) -> dict[str, Any]: ... - def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index b8e239b..1e36c42 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -1,8 +1,6 @@ -use crate::APIMod; use chrono::{TimeZone, Utc}; use pyo3::types::PyDict; use pyo3::{prelude::*, IntoPyObjectExt}; -use std::collections::HashMap; use std::io::Read; pub fn read_object( @@ -206,13 +204,12 @@ fn read_array( let obj1 = read_object(py, cursor, false)?; if obj1.extract::(py).map_or(false, |k| k.len() == 2) { let obj2 = read_object(py, cursor, true)?; - return Ok(APIMod { - acronym: obj1.extract::(py)?, - settings: obj2.extract::>(py)?, - } - .into_pyobject(py)? - .into_any() - .unbind()); + + let api_mod_dict = PyDict::new(py); + api_mod_dict.set_item("acronym", obj1)?; + api_mod_dict.set_item("settings", obj2)?; + + return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind()); } else { items.push(obj1); i += 1; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 0e0907c..3ff4864 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -1,8 +1,7 @@ -use crate::APIMod; -use chrono::{DateTime, Utc}; -use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods}; +use chrono::{DateTime, Utc}; +use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods}; use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString}; -use pyo3::{Bound, PyAny, PyRef, Python}; +use pyo3::{Bound, PyAny}; use std::io::Write; fn write_list(buf: &mut Vec, obj: &Bound<'_, PyList>) { @@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec, obj: &Bound<'_, PyDict>) { } } -fn write_nil(buf: &mut Vec){ +fn write_nil(buf: &mut Vec) { rmp::encode::write_nil(buf).unwrap(); } -// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs -fn write_api_mod(buf: &mut Vec, api_mod: PyRef) { - rmp::encode::write_array_len(buf, 2).unwrap(); - rmp::encode::write_str(buf, &api_mod.acronym).unwrap(); - rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap(); - for (k, v) in api_mod.settings.iter() { - rmp::encode::write_str(buf, k).unwrap(); - Python::with_gil(|py| write_object(buf, &v.bind(py))); +fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool { + if let Ok(Some(acronym)) = dict.get_item("acronym") { + if let Ok(acronym_str) = acronym.extract::() { + return acronym_str.len() == 2; + } } + false +} + +// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs +fn write_api_mod(buf: &mut Vec, api_mod: &Bound<'_, PyDict>) -> PyResult<()> { + let acronym = api_mod + .get_item("acronym")? + .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?; + let acronym_str = acronym.extract::()?; + + let settings = api_mod + .get_item("settings")? + .unwrap_or_else(|| PyDict::new(acronym.py()).into_any()); + let settings_dict = settings.downcast::()?; + + rmp::encode::write_array_len(buf, 2).unwrap(); + rmp::encode::write_str(buf, &acronym_str).unwrap(); + rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap(); + + for (k, v) in settings_dict.iter() { + let key_str = k.extract::()?; + rmp::encode::write_str(buf, &key_str).unwrap(); + write_object(buf, &v); + } + + Ok(()) } fn write_datetime(buf: &mut Vec, obj: &Bound<'_, PyDateTime>) { @@ -111,21 +133,23 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { } else if let Ok(string) = obj.downcast::() { write_string(buf, string); } else if let Ok(boolean) = obj.downcast::() { - write_bool(buf, boolean); + write_bool(buf, boolean); } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); + write_float(buf, float); } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() { - write_hashmap(buf, dict); + if is_api_mod(dict) { + write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict)); + } else { + write_hashmap(buf, dict); + } } else if let Ok(_none) = obj.downcast::() { write_nil(buf); } else if let Ok(datetime) = obj.downcast::() { write_datetime(buf, datetime); - } else if let Ok(api_mod) = obj.extract::>() { - write_api_mod(buf, api_mod); } else { panic!("Unsupported type"); } diff --git a/packages/msgpack_lazer_api/src/lib.rs b/packages/msgpack_lazer_api/src/lib.rs index fda540c..220e645 100644 --- a/packages/msgpack_lazer_api/src/lib.rs +++ b/packages/msgpack_lazer_api/src/lib.rs @@ -2,30 +2,6 @@ mod decode; mod encode; use pyo3::prelude::*; -use std::collections::HashMap; - -#[pyclass] -struct APIMod { - #[pyo3(get, set)] - acronym: String, - #[pyo3(get, set)] - settings: HashMap, -} - -#[pymethods] -impl APIMod { - #[new] - fn new(acronym: String, settings: HashMap) -> Self { - APIMod { acronym, settings } - } - - fn __repr__(&self) -> String { - format!( - "APIMod(acronym='{}', settings={:?})", - self.acronym, self.settings - ) - } -} #[pyfunction] #[pyo3(name = "encode")] @@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult { fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(encode_py, m)?)?; m.add_function(wrap_pyfunction!(decode_py, m)?)?; - m.add_class::()?; Ok(()) }