fix(signalr): encode enum by index

This commit is contained in:
MingxuanGame
2025-08-02 14:59:12 +00:00
parent a11ea743a7
commit 5ccb35dc8b
3 changed files with 27 additions and 19 deletions

View File

@@ -15,26 +15,26 @@ from pydantic import (
) )
def serialize_msgpack(v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return serialize_to_list(v)
elif issubclass(typ, list):
return TypeAdapter(
typ, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
def serialize_to_list(value: BaseModel) -> list[Any]: def serialize_to_list(value: BaseModel) -> list[Any]:
data = [] data = []
for field, info in value.__class__.model_fields.items(): for field, info in value.__class__.model_fields.items():
v = getattr(value, field) data.append(serialize_msgpack(v=getattr(value, field)))
typ = v.__class__
if issubclass(typ, BaseModel):
data.append(serialize_to_list(v))
elif issubclass(typ, list):
data.append(
TypeAdapter(
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
)
elif issubclass(typ, datetime.datetime):
data.append([v, 0])
elif issubclass(typ, Enum):
list_ = list(typ)
data.append(list_.index(v) if v in list_ else v.value)
else:
data.append(v)
return data return data

View File

@@ -2,12 +2,14 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from enum import Enum
import inspect
import time import time
from typing import Any from typing import Any
from app.config import settings from app.config import settings
from app.log import logger from app.log import logger
from app.models.signalr import UserState from app.models.signalr import UserState, _by_index
from app.signalr.exception import InvokeException from app.signalr.exception import InvokeException
from app.signalr.packet import ( from app.signalr.packet import (
ClosePacket, ClosePacket,
@@ -265,6 +267,10 @@ class Hub[TState: UserState]:
continue continue
if issubclass(param.annotation, BaseModel): if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0))) call_params.append(param.annotation.model_validate(args.pop(0)))
elif inspect.isclass(param.annotation) and issubclass(
param.annotation, Enum
):
call_params.append(_by_index(args.pop(0), param.annotation))
else: else:
call_params.append(args.pop(0)) call_params.append(args.pop(0))
return await method_(client, *call_params) return await method_(client, *call_params)

View File

@@ -8,6 +8,8 @@ from typing import (
Protocol as TypingProtocol, Protocol as TypingProtocol,
) )
from app.models.signalr import serialize_msgpack
import msgpack_lazer_api as m import msgpack_lazer_api as m
SEP = b"\x1e" SEP = b"\x1e"
@@ -151,7 +153,7 @@ class MsgpackProtocol:
] ]
) )
if packet.arguments is not None: if packet.arguments is not None:
payload.append(packet.arguments) payload.append([serialize_msgpack(arg) for arg in packet.arguments])
if packet.stream_ids is not None: if packet.stream_ids is not None:
payload.append(packet.stream_ids) payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket): elif isinstance(packet, CompletionPacket):