fix(signalr): use custom msgpack to encode/decode

This commit is contained in:
MingxuanGame
2025-07-30 06:01:17 +00:00
parent a53c63a33a
commit 4a5a1c86c6
17 changed files with 1191 additions and 892 deletions

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import datetime
from typing import Any, get_origin
import msgpack
from pydantic import (
BaseModel,
ConfigDict,
@@ -24,11 +23,11 @@ def serialize_to_list(value: BaseModel) -> list[Any]:
elif anno and issubclass(anno, list):
data.append(
TypeAdapter(
info.annotation,
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
)
elif isinstance(v, datetime.datetime):
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
data.append([v, 0])
else:
data.append(v)
return data

View File

@@ -11,15 +11,8 @@ from .score import (
)
from .signalr import MessagePackArrayModel, UserState
import msgpack
from pydantic import BaseModel, Field, field_validator
class APIMod(MessagePackArrayModel):
acronym: str
settings: dict[str, Any] | list = Field(
default_factory=dict
) # FIXME: with settings
from msgpack_lazer_api import APIMod
from pydantic import BaseModel, ConfigDict, Field, field_validator
class SpectatedUserState(IntEnum):
@@ -32,6 +25,8 @@ class SpectatedUserState(IntEnum):
class SpectatorState(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
@@ -58,6 +53,8 @@ class ScoreProcessorStatistics(MessagePackArrayModel):
class FrameHeader(MessagePackArrayModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
total_score: int
acc: float
combo: int
@@ -70,10 +67,8 @@ class FrameHeader(MessagePackArrayModel):
@field_validator("received_time", mode="before")
@classmethod
def validate_timestamp(cls, v: Any) -> datetime.datetime:
if isinstance(v, msgpack.ext.Timestamp):
return v.to_datetime()
if isinstance(v, list):
return v[0].to_datetime()
return v[0]
if isinstance(v, datetime.datetime):
return v
if isinstance(v, int | float):
@@ -111,6 +106,8 @@ class APIUser(BaseModel):
class ScoreInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mods: list[APIMod]
user: APIUser
ruleset: int

View File

@@ -4,7 +4,6 @@ from typing import Literal
from app.database import User as DBUser
from app.dependencies.database import get_db
from app.dependencies import get_current_user
from app.models.score import INT_TO_MODE
from app.models.user import User as ApiUser
from app.utils import convert_db_user_to_api_user
@@ -20,24 +19,18 @@ from sqlmodel.sql.expression import col
# ---------- Shared Utility ----------
async def get_user_by_lookup(
db: AsyncSession,
lookup: str,
key: str = "id"
db: AsyncSession, lookup: str, key: str = "id"
) -> DBUser | None:
"""根据查找方式获取用户"""
if key == "id":
try:
user_id = int(lookup)
result = await db.exec(
select(DBUser).where(DBUser.id == user_id)
)
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
return result.first()
except ValueError:
return None
elif key == "username":
result = await db.exec(
select(DBUser).where(DBUser.name == lookup)
)
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
return result.first()
else:
return None
@@ -50,6 +43,7 @@ class BatchUserResponse(BaseModel):
@router.get("/users", response_model=BatchUserResponse)
@router.get("/users/lookup", response_model=BatchUserResponse)
@router.get("/users/lookup/", response_model=BatchUserResponse)
async def get_users(
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
include_variant_statistics: bool = Query(default=False), # TODO: future use
@@ -75,41 +69,43 @@ async def get_users(
)
# ---------- Individual User ----------
@router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
@router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
async def get_user_with_mode(
user_lookup: str,
mode: Literal["osu", "taiko", "fruits", "mania"],
key: Literal["id", "username"] = Query("id"),
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取指定游戏模式的用户信息"""
user = await get_user_by_lookup(db, user_lookup, key)
if not user:
raise HTTPException(status_code=404, detail="User not found")
# # ---------- Individual User ----------
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
# async def get_user_with_mode(
# user_lookup: str,
# mode: Literal["osu", "taiko", "fruits", "mania"],
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取指定游戏模式的用户信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
return await convert_db_user_to_api_user(user, mode)
# return await convert_db_user_to_api_user(user, mode)
@router.get("/users/{user_lookup}", response_model=ApiUser)
@router.get("/users/{user_lookup}/", response_model=ApiUser)
async def get_user_default(
user_lookup: str,
key: Literal["id", "username"] = Query("id"),
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取用户信息默认使用osu模式但包含所有模式的统计信息"""
user = await get_user_by_lookup(db, user_lookup, key)
if not user:
raise HTTPException(status_code=404, detail="User not found")
# @router.get("/users/{user_lookup}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
# async def get_user_default(
# user_lookup: str,
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取用户信息默认使用osu模式但包含所有模式的统计信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
return await convert_db_user_to_api_user(user, "osu")
# return await convert_db_user_to_api_user(user, "osu")
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
@router.get("/users/{user}/", response_model=ApiUser)
@router.get("/users/{user}", response_model=ApiUser)
async def get_user_info(
user: str,
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",

View File

@@ -8,7 +8,7 @@ from typing import (
Protocol as TypingProtocol,
)
import msgpack
import msgpack_lazer_api as m
SEP = b"\x1e"
@@ -104,11 +104,7 @@ class MsgpackProtocol:
def decode(input: bytes) -> list[Packet]:
length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length]
# FIXME: custom deserializer for APIMod
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
unpacked = msgpack.unpackb(
message_data, raw=False, strict_map_key=False, use_list=True
)
unpacked = m.decode(message_data)
packet_type = PacketType(unpacked[0])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
@@ -180,7 +176,7 @@ class MsgpackProtocol:
)
elif isinstance(packet, PingPacket):
payload.pop(-1)
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
data = m.encode(payload)
return MsgpackProtocol._encode_varint(len(data)) + data