fix(signalr): encode enum by index

This commit is contained in:
MingxuanGame
2025-08-02 14:59:12 +00:00
parent a11ea743a7
commit 5ccb35dc8b
3 changed files with 27 additions and 19 deletions

View File

@@ -15,26 +15,26 @@ from pydantic import (
)
def serialize_msgpack(v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return serialize_to_list(v)
elif issubclass(typ, list):
return TypeAdapter(
typ, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
def serialize_to_list(value: BaseModel) -> list[Any]:
data = []
for field, info in value.__class__.model_fields.items():
v = getattr(value, field)
typ = v.__class__
if issubclass(typ, BaseModel):
data.append(serialize_to_list(v))
elif issubclass(typ, list):
data.append(
TypeAdapter(
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
)
elif issubclass(typ, datetime.datetime):
data.append([v, 0])
elif issubclass(typ, Enum):
list_ = list(typ)
data.append(list_.index(v) if v in list_ else v.value)
else:
data.append(v)
data.append(serialize_msgpack(v=getattr(value, field)))
return data

View File

@@ -2,12 +2,14 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from enum import Enum
import inspect
import time
from typing import Any
from app.config import settings
from app.log import logger
from app.models.signalr import UserState
from app.models.signalr import UserState, _by_index
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
@@ -265,6 +267,10 @@ class Hub[TState: UserState]:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
elif inspect.isclass(param.annotation) and issubclass(
param.annotation, Enum
):
call_params.append(_by_index(args.pop(0), param.annotation))
else:
call_params.append(args.pop(0))
return await method_(client, *call_params)

View File

@@ -8,6 +8,8 @@ from typing import (
Protocol as TypingProtocol,
)
from app.models.signalr import serialize_msgpack
import msgpack_lazer_api as m
SEP = b"\x1e"
@@ -151,7 +153,7 @@ class MsgpackProtocol:
]
)
if packet.arguments is not None:
payload.append(packet.arguments)
payload.append([serialize_msgpack(arg) for arg in packet.arguments])
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):