fix(signalr): use custom msgpack to encode/decode
This commit is contained in:
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import datetime
|
||||
from typing import Any, get_origin
|
||||
|
||||
import msgpack
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -24,11 +23,11 @@ def serialize_to_list(value: BaseModel) -> list[Any]:
|
||||
elif anno and issubclass(anno, list):
|
||||
data.append(
|
||||
TypeAdapter(
|
||||
info.annotation,
|
||||
info.annotation, config=ConfigDict(arbitrary_types_allowed=True)
|
||||
).dump_python(v)
|
||||
)
|
||||
elif isinstance(v, datetime.datetime):
|
||||
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
|
||||
data.append([v, 0])
|
||||
else:
|
||||
data.append(v)
|
||||
return data
|
||||
|
||||
@@ -11,15 +11,8 @@ from .score import (
|
||||
)
|
||||
from .signalr import MessagePackArrayModel, UserState
|
||||
|
||||
import msgpack
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class APIMod(MessagePackArrayModel):
|
||||
acronym: str
|
||||
settings: dict[str, Any] | list = Field(
|
||||
default_factory=dict
|
||||
) # FIXME: with settings
|
||||
from msgpack_lazer_api import APIMod
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class SpectatedUserState(IntEnum):
|
||||
@@ -32,6 +25,8 @@ class SpectatedUserState(IntEnum):
|
||||
|
||||
|
||||
class SpectatorState(MessagePackArrayModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
beatmap_id: int | None = None
|
||||
ruleset_id: int | None = None # 0,1,2,3
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
@@ -58,6 +53,8 @@ class ScoreProcessorStatistics(MessagePackArrayModel):
|
||||
|
||||
|
||||
class FrameHeader(MessagePackArrayModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
total_score: int
|
||||
acc: float
|
||||
combo: int
|
||||
@@ -70,10 +67,8 @@ class FrameHeader(MessagePackArrayModel):
|
||||
@field_validator("received_time", mode="before")
|
||||
@classmethod
|
||||
def validate_timestamp(cls, v: Any) -> datetime.datetime:
|
||||
if isinstance(v, msgpack.ext.Timestamp):
|
||||
return v.to_datetime()
|
||||
if isinstance(v, list):
|
||||
return v[0].to_datetime()
|
||||
return v[0]
|
||||
if isinstance(v, datetime.datetime):
|
||||
return v
|
||||
if isinstance(v, int | float):
|
||||
@@ -111,6 +106,8 @@ class APIUser(BaseModel):
|
||||
|
||||
|
||||
class ScoreInfo(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
mods: list[APIMod]
|
||||
user: APIUser
|
||||
ruleset: int
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Literal
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.score import INT_TO_MODE
|
||||
from app.models.user import User as ApiUser
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
@@ -20,24 +19,18 @@ from sqlmodel.sql.expression import col
|
||||
|
||||
# ---------- Shared Utility ----------
|
||||
async def get_user_by_lookup(
|
||||
db: AsyncSession,
|
||||
lookup: str,
|
||||
key: str = "id"
|
||||
db: AsyncSession, lookup: str, key: str = "id"
|
||||
) -> DBUser | None:
|
||||
"""根据查找方式获取用户"""
|
||||
if key == "id":
|
||||
try:
|
||||
user_id = int(lookup)
|
||||
result = await db.exec(
|
||||
select(DBUser).where(DBUser.id == user_id)
|
||||
)
|
||||
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
|
||||
return result.first()
|
||||
except ValueError:
|
||||
return None
|
||||
elif key == "username":
|
||||
result = await db.exec(
|
||||
select(DBUser).where(DBUser.name == lookup)
|
||||
)
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
|
||||
return result.first()
|
||||
else:
|
||||
return None
|
||||
@@ -50,6 +43,7 @@ class BatchUserResponse(BaseModel):
|
||||
|
||||
@router.get("/users", response_model=BatchUserResponse)
|
||||
@router.get("/users/lookup", response_model=BatchUserResponse)
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse)
|
||||
async def get_users(
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
|
||||
include_variant_statistics: bool = Query(default=False), # TODO: future use
|
||||
@@ -75,41 +69,43 @@ async def get_users(
|
||||
)
|
||||
|
||||
|
||||
# ---------- Individual User ----------
|
||||
@router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
|
||||
@router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
|
||||
async def get_user_with_mode(
|
||||
user_lookup: str,
|
||||
mode: Literal["osu", "taiko", "fruits", "mania"],
|
||||
key: Literal["id", "username"] = Query("id"),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取指定游戏模式的用户信息"""
|
||||
user = await get_user_by_lookup(db, user_lookup, key)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
# # ---------- Individual User ----------
|
||||
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
|
||||
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
|
||||
# async def get_user_with_mode(
|
||||
# user_lookup: str,
|
||||
# mode: Literal["osu", "taiko", "fruits", "mania"],
|
||||
# key: Literal["id", "username"] = Query("id"),
|
||||
# current_user: DBUser = Depends(get_current_user),
|
||||
# db: AsyncSession = Depends(get_db),
|
||||
# ):
|
||||
# """获取指定游戏模式的用户信息"""
|
||||
# user = await get_user_by_lookup(db, user_lookup, key)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return await convert_db_user_to_api_user(user, mode)
|
||||
# return await convert_db_user_to_api_user(user, mode)
|
||||
|
||||
|
||||
@router.get("/users/{user_lookup}", response_model=ApiUser)
|
||||
@router.get("/users/{user_lookup}/", response_model=ApiUser)
|
||||
async def get_user_default(
|
||||
user_lookup: str,
|
||||
key: Literal["id", "username"] = Query("id"),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取用户信息(默认使用osu模式,但包含所有模式的统计信息)"""
|
||||
user = await get_user_by_lookup(db, user_lookup, key)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
# @router.get("/users/{user_lookup}", response_model=ApiUser)
|
||||
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
|
||||
# async def get_user_default(
|
||||
# user_lookup: str,
|
||||
# key: Literal["id", "username"] = Query("id"),
|
||||
# current_user: DBUser = Depends(get_current_user),
|
||||
# db: AsyncSession = Depends(get_db),
|
||||
# ):
|
||||
# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)"""
|
||||
# user = await get_user_by_lookup(db, user_lookup, key)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return await convert_db_user_to_api_user(user, "osu")
|
||||
# return await convert_db_user_to_api_user(user, "osu")
|
||||
|
||||
|
||||
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/users/{user}/", response_model=ApiUser)
|
||||
@router.get("/users/{user}", response_model=ApiUser)
|
||||
async def get_user_info(
|
||||
user: str,
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import (
|
||||
Protocol as TypingProtocol,
|
||||
)
|
||||
|
||||
import msgpack
|
||||
import msgpack_lazer_api as m
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
@@ -104,11 +104,7 @@ class MsgpackProtocol:
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
length, offset = MsgpackProtocol._decode_varint(input)
|
||||
message_data = input[offset : offset + length]
|
||||
# FIXME: custom deserializer for APIMod
|
||||
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
||||
unpacked = msgpack.unpackb(
|
||||
message_data, raw=False, strict_map_key=False, use_list=True
|
||||
)
|
||||
unpacked = m.decode(message_data)
|
||||
packet_type = PacketType(unpacked[0])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
@@ -180,7 +176,7 @@ class MsgpackProtocol:
|
||||
)
|
||||
elif isinstance(packet, PingPacket):
|
||||
payload.pop(-1)
|
||||
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||
data = m.encode(payload)
|
||||
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user