From d399cb52e261571fb946bb0b6e809023f932bd6d Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 11:00:57 +0000 Subject: [PATCH] fix(signarl): wrong msgpack encode --- app/models/signalr.py | 46 +++++++++++++++++-- app/signalr/packet.py | 2 +- .../msgpack_lazer_api/msgpack_lazer_api.pyi | 2 +- packages/msgpack_lazer_api/src/decode.rs | 4 +- packages/msgpack_lazer_api/src/encode.rs | 10 ++-- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/app/models/signalr.py b/app/models/signalr.py index 37b2741..202da4f 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -1,10 +1,12 @@ from __future__ import annotations import datetime -from typing import Any, get_origin +from enum import Enum +from typing import Any from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Field, TypeAdapter, @@ -17,22 +19,56 @@ def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): v = getattr(value, field) - anno = get_origin(info.annotation) - if anno and issubclass(anno, BaseModel): + typ = v.__class__ + if issubclass(typ, BaseModel): data.append(serialize_to_list(v)) - elif anno and issubclass(anno, list): + elif issubclass(typ, list): data.append( TypeAdapter( info.annotation, config=ConfigDict(arbitrary_types_allowed=True) ).dump_python(v) ) - elif isinstance(v, datetime.datetime): + elif issubclass(typ, datetime.datetime): data.append([v, 0]) + elif issubclass(typ, Enum): + list_ = list(typ) + data.append(list_.index(v) if v in list_ else v.value) else: data.append(v) return data +def _by_index(v: Any, class_: type[Enum]): + enum_list = list(class_) + if not isinstance(v, int): + return v + if 0 <= v < len(enum_list): + return enum_list[v] + raise ValueError( + f"Value {v} is out of range for enum " + f"{class_.__name__} with {len(enum_list)} items" + ) + + +def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator: + return BeforeValidator(lambda v: _by_index(v, enum_class)) + + +def msgpack_union(v): + data = v[1] + data.append(v[0]) + return data + + +def msgpack_union_dump(v: BaseModel) -> list[Any]: + _type = getattr(v, "type", None) + if _type is None: + raise ValueError( + f"Model {v.__class__.__name__} does not have a '_type' attribute" + ) + return [_type, serialize_to_list(v)] + + class MessagePackArrayModel(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index e361ef8..387231c 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -158,7 +158,7 @@ class MsgpackProtocol: result_kind = 2 if packet.error: result_kind = 1 - elif packet.result is None: + elif packet.result is not None: result_kind = 3 payload.extend( [ diff --git a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi index 88b79c5..b8653f0 100644 --- a/packages/msgpack_lazer_api/msgpack_lazer_api.pyi +++ b/packages/msgpack_lazer_api/msgpack_lazer_api.pyi @@ -5,7 +5,7 @@ class APIMod: @property def acronym(self) -> str: ... @property - def settings(self) -> str: ... + def settings(self) -> dict[str, Any]: ... def encode(obj: Any) -> bytes: ... def decode(data: bytes) -> Any: ... diff --git a/packages/msgpack_lazer_api/src/decode.rs b/packages/msgpack_lazer_api/src/decode.rs index 15156ca..b8e239b 100644 --- a/packages/msgpack_lazer_api/src/decode.rs +++ b/packages/msgpack_lazer_api/src/decode.rs @@ -13,6 +13,8 @@ pub fn read_object( match rmp::decode::read_marker(cursor) { Ok(marker) => match marker { rmp::Marker::Null => Ok(py.None()), + rmp::Marker::True => Ok(true.into_py_any(py)?), + rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixPos(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::FixNeg(val) => Ok(val.into_pyobject(py)?.into_any().unbind()), rmp::Marker::U8 => { @@ -86,8 +88,6 @@ pub fn read_object( cursor.read_exact(&mut data).map_err(to_py_err)?; Ok(data.into_pyobject(py)?.into_any().unbind()) } - rmp::Marker::True => Ok(true.into_py_any(py)?), - rmp::Marker::False => Ok(false.into_py_any(py)?), rmp::Marker::FixStr(len) => read_string(py, cursor, len as u32), rmp::Marker::Str8 => { let mut buf = [0u8; 1]; diff --git a/packages/msgpack_lazer_api/src/encode.rs b/packages/msgpack_lazer_api/src/encode.rs index 88a732b..0e0907c 100644 --- a/packages/msgpack_lazer_api/src/encode.rs +++ b/packages/msgpack_lazer_api/src/encode.rs @@ -110,12 +110,12 @@ pub fn write_object(buf: &mut Vec, obj: &Bound<'_, PyAny>) { write_list(buf, list); } else if let Ok(string) = obj.downcast::() { write_string(buf, string); - } else if let Ok(integer) = obj.downcast::() { - write_integer(buf, integer); - } else if let Ok(float) = obj.downcast::() { - write_float(buf, float); } else if let Ok(boolean) = obj.downcast::() { - write_bool(buf, boolean); + write_bool(buf, boolean); + } else if let Ok(float) = obj.downcast::() { + write_float(buf, float); + } else if let Ok(integer) = obj.downcast::() { + write_integer(buf, integer); } else if let Ok(bytes) = obj.downcast::() { write_bin(buf, bytes); } else if let Ok(dict) = obj.downcast::() {