fix(signalr): encode enum by index
This commit is contained in:
@@ -15,26 +15,26 @@ from pydantic import (
|
||||
)
|
||||
|
||||
|
||||
def serialize_msgpack(v: Any) -> Any:
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return serialize_to_list(v)
|
||||
elif issubclass(typ, list):
|
||||
return TypeAdapter(
|
||||
typ, config=ConfigDict(arbitrary_types_allowed=True)
|
||||
).dump_python(v)
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
return [v, 0]
|
||||
elif issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v) if v in list_ else v.value
|
||||
return v
|
||||
|
||||
|
||||
def serialize_to_list(value: BaseModel) -> list[Any]:
|
||||
data = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
v = getattr(value, field)
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
data.append(serialize_to_list(v))
|
||||
elif issubclass(typ, list):
|
||||
data.append(
|
||||
TypeAdapter(
|
||||
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
|
||||
).dump_python(v)
|
||||
)
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
data.append([v, 0])
|
||||
elif issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
data.append(list_.index(v) if v in list_ else v.value)
|
||||
else:
|
||||
data.append(v)
|
||||
data.append(serialize_msgpack(v=getattr(value, field)))
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState
|
||||
from app.models.signalr import UserState, _by_index
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
@@ -265,6 +267,10 @@ class Hub[TState: UserState]:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
elif inspect.isclass(param.annotation) and issubclass(
|
||||
param.annotation, Enum
|
||||
):
|
||||
call_params.append(_by_index(args.pop(0), param.annotation))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
return await method_(client, *call_params)
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import (
|
||||
Protocol as TypingProtocol,
|
||||
)
|
||||
|
||||
from app.models.signalr import serialize_msgpack
|
||||
|
||||
import msgpack_lazer_api as m
|
||||
|
||||
SEP = b"\x1e"
|
||||
@@ -151,7 +153,7 @@ class MsgpackProtocol:
|
||||
]
|
||||
)
|
||||
if packet.arguments is not None:
|
||||
payload.append(packet.arguments)
|
||||
payload.append([serialize_msgpack(arg) for arg in packet.arguments])
|
||||
if packet.stream_ids is not None:
|
||||
payload.append(packet.stream_ids)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
|
||||
Reference in New Issue
Block a user