diff --git a/app/models/signalr.py b/app/models/signalr.py index 202da4f..9e189e9 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -15,26 +15,26 @@ from pydantic import ( ) +def serialize_msgpack(v: Any) -> Any: + typ = v.__class__ + if issubclass(typ, BaseModel): + return serialize_to_list(v) + elif issubclass(typ, list): + return TypeAdapter( + typ, config=ConfigDict(arbitrary_types_allowed=True) + ).dump_python(v) + elif issubclass(typ, datetime.datetime): + return [v, 0] + elif issubclass(typ, Enum): + list_ = list(typ) + return list_.index(v) if v in list_ else v.value + return v + + def serialize_to_list(value: BaseModel) -> list[Any]: data = [] for field, info in value.__class__.model_fields.items(): - v = getattr(value, field) - typ = v.__class__ - if issubclass(typ, BaseModel): - data.append(serialize_to_list(v)) - elif issubclass(typ, list): - data.append( - TypeAdapter( - info.annotation, config=ConfigDict(arbitrary_types_allowed=True) - ).dump_python(v) - ) - elif issubclass(typ, datetime.datetime): - data.append([v, 0]) - elif issubclass(typ, Enum): - list_ = list(typ) - data.append(list_.index(v) if v in list_ else v.value) - else: - data.append(v) + data.append(serialize_msgpack(v=getattr(value, field))) return data diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 276140f..a11fbe7 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -2,12 +2,14 @@ from __future__ import annotations from abc import abstractmethod import asyncio +from enum import Enum +import inspect import time from typing import Any from app.config import settings from app.log import logger -from app.models.signalr import UserState +from app.models.signalr import UserState, _by_index from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, @@ -265,6 +267,10 @@ class Hub[TState: UserState]: continue if issubclass(param.annotation, BaseModel): call_params.append(param.annotation.model_validate(args.pop(0))) + elif inspect.isclass(param.annotation) and issubclass( + param.annotation, Enum + ): + call_params.append(_by_index(args.pop(0), param.annotation)) else: call_params.append(args.pop(0)) return await method_(client, *call_params) diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 387231c..de5ce8a 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -8,6 +8,8 @@ from typing import ( Protocol as TypingProtocol, ) +from app.models.signalr import serialize_msgpack + import msgpack_lazer_api as m SEP = b"\x1e" @@ -151,7 +153,7 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append(packet.arguments) + payload.append([serialize_msgpack(arg) for arg in packet.arguments]) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket):