chore(merge): merge branch 'main' into feat/multiplayer-api

This commit is contained in:
MingxuanGame
2025-08-03 09:50:53 +00:00
13 changed files with 434 additions and 325 deletions

View File

@@ -1,114 +1,85 @@
from __future__ import annotations from __future__ import annotations
from enum import IntEnum from enum import IntEnum
from typing import Any, Literal from typing import ClassVar, Literal
from app.models.signalr import UserState from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
class _UserActivity(BaseModel): class _UserActivity(SignalRUnionMessage): ...
model_config = ConfigDict(serialize_by_alias=True)
type: Literal[
"ChoosingBeatmap",
"InSoloGame",
"WatchingReplay",
"SpectatingUser",
"SearchingForLobby",
"InLobby",
"InMultiplayerGame",
"SpectatingMultiplayerGame",
"InPlaylistGame",
"EditingBeatmap",
"ModdingBeatmap",
"TestingBeatmap",
"InDailyChallengeLobby",
"PlayingDailyChallenge",
] = Field(alias="$dtype")
value: Any | None = Field(alias="$value")
class ChoosingBeatmap(_UserActivity): class ChoosingBeatmap(_UserActivity):
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype") union_type: ClassVar[Literal[11]] = 11
class InGameValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
ruleset_id: int = Field(alias="RulesetID")
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
class _InGame(_UserActivity): class _InGame(_UserActivity):
value: InGameValue = Field(alias="$value") beatmap_id: int
beatmap_display_title: str
ruleset_id: int
ruleset_playing_verb: str
class InSoloGame(_InGame): class InSoloGame(_InGame):
type: Literal["InSoloGame"] = Field(alias="$dtype") union_type: ClassVar[Literal[12]] = 12
class InMultiplayerGame(_InGame): class InMultiplayerGame(_InGame):
type: Literal["InMultiplayerGame"] = Field(alias="$dtype") union_type: ClassVar[Literal[23]] = 23
class SpectatingMultiplayerGame(_InGame): class SpectatingMultiplayerGame(_InGame):
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype") union_type: ClassVar[Literal[24]] = 24
class InPlaylistGame(_InGame): class InPlaylistGame(_InGame):
type: Literal["InPlaylistGame"] = Field(alias="$dtype") union_type: ClassVar[Literal[31]] = 31
class EditingBeatmapValue(BaseModel): class PlayingDailyChallenge(_InGame):
beatmap_id: int = Field(alias="BeatmapID") union_type: ClassVar[Literal[52]] = 52
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class EditingBeatmap(_UserActivity): class EditingBeatmap(_UserActivity):
type: Literal["EditingBeatmap"] = Field(alias="$dtype") union_type: ClassVar[Literal[41]] = 41
value: EditingBeatmapValue = Field(alias="$value") beatmap_id: int
beatmap_display_title: str
class TestingBeatmap(_UserActivity): class TestingBeatmap(EditingBeatmap):
type: Literal["TestingBeatmap"] = Field(alias="$dtype") union_type: ClassVar[Literal[43]] = 43
class ModdingBeatmap(_UserActivity): class ModdingBeatmap(EditingBeatmap):
type: Literal["ModdingBeatmap"] = Field(alias="$dtype") union_type: ClassVar[Literal[42]] = 42
class WatchingReplayValue(BaseModel):
score_id: int = Field(alias="ScoreID")
player_name: str = Field(alias="PlayerName")
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class WatchingReplay(_UserActivity): class WatchingReplay(_UserActivity):
type: Literal["WatchingReplay"] = Field(alias="$dtype") union_type: ClassVar[Literal[13]] = 13
value: int | None = Field(alias="$value") # Replay ID score_id: int
player_name: str
beatmap_id: int
beatmap_display_title: str
class SpectatingUser(WatchingReplay): class SpectatingUser(WatchingReplay):
type: Literal["SpectatingUser"] = Field(alias="$dtype") union_type: ClassVar[Literal[14]] = 14
class SearchingForLobby(_UserActivity): class SearchingForLobby(_UserActivity):
type: Literal["SearchingForLobby"] = Field(alias="$dtype") union_type: ClassVar[Literal[21]] = 21
class InLobbyValue(BaseModel):
room_id: int = Field(alias="RoomID")
room_name: str = Field(alias="RoomName")
class InLobby(_UserActivity): class InLobby(_UserActivity):
type: Literal["InLobby"] = "InLobby" union_type: ClassVar[Literal[22]] = 22
room_id: int
room_name: str
class InDailyChallengeLobby(_UserActivity): class InDailyChallengeLobby(_UserActivity):
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype") union_type: ClassVar[Literal[51]] = 51
UserActivity = ( UserActivity = (
@@ -128,23 +99,28 @@ UserActivity = (
) )
class MetadataClientState(UserState): class UserPresence(BaseModel):
user_activity: UserActivity | None = None activity: UserActivity | None = Field(
status: OnlineStatus | None = None default=None, metadata=SignalRMeta(use_upper_case=True)
)
def to_dict(self) -> dict[str, Any] | None: status: OnlineStatus | None = Field(
if self.status is None or self.status == OnlineStatus.OFFLINE: default=None, metadata=SignalRMeta(use_upper_case=True)
return None )
dumped = self.model_dump(by_alias=True, exclude_none=True)
return {
"Activity": dumped.get("user_activity"),
"Status": dumped.get("status"),
}
@property @property
def pushable(self) -> bool: def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE return self.status is not None and self.status != OnlineStatus.OFFLINE
@property
def for_push(self) -> "UserPresence | None":
return UserPresence(
activity=self.activity,
status=self.status,
)
class MetadataClientState(UserPresence, UserState): ...
class OnlineStatus(IntEnum): class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身 OFFLINE = 0 # 隐身

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum, IntEnum from enum import Enum
from typing import Literal, TypedDict from typing import Literal, TypedDict
from .mods import API_MODS, APIMod, init_mods from .mods import API_MODS, APIMod, init_mods
@@ -93,43 +93,6 @@ class HitResult(str, Enum):
) )
class HitResultInt(IntEnum):
PERFECT = 0
GREAT = 1
GOOD = 2
OK = 3
MEH = 4
MISS = 5
LARGE_TICK_HIT = 6
SMALL_TICK_HIT = 7
SLIDER_TAIL_HIT = 8
LARGE_BONUS = 9
SMALL_BONUS = 10
LARGE_TICK_MISS = 11
SMALL_TICK_MISS = 12
IGNORE_HIT = 13
IGNORE_MISS = 14
NONE = 15
COMBO_BREAK = 16
LEGACY_COMBO_INCREASE = 99
def is_hit(self) -> bool:
return self not in (
HitResultInt.NONE,
HitResultInt.IGNORE_MISS,
HitResultInt.COMBO_BREAK,
HitResultInt.LARGE_TICK_MISS,
HitResultInt.SMALL_TICK_MISS,
HitResultInt.MISS,
)
class LeaderboardType(Enum): class LeaderboardType(Enum):
GLOBAL = "global" GLOBAL = "global"
FRIENDS = "friend" FRIENDS = "friend"
@@ -138,7 +101,6 @@ class LeaderboardType(Enum):
ScoreStatistics = dict[HitResult, int] ScoreStatistics = dict[HitResult, int]
ScoreStatisticsInt = dict[HitResultInt, int]
class SoloScoreSubmissionInfo(BaseModel): class SoloScoreSubmissionInfo(BaseModel):
@@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel):
class LegacyReplaySoloScoreInfo(TypedDict): class LegacyReplaySoloScoreInfo(TypedDict):
online_id: int online_id: int
mods: list[APIMod] mods: list[APIMod]
statistics: ScoreStatisticsInt statistics: ScoreStatistics
maximum_statistics: ScoreStatisticsInt maximum_statistics: ScoreStatistics
client_version: str client_version: str
rank: Rank rank: Rank
user_id: int user_id: int

View File

@@ -1,90 +1,23 @@
from __future__ import annotations from __future__ import annotations
import datetime from dataclasses import dataclass
from enum import Enum from typing import ClassVar
from typing import Any
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
BeforeValidator,
ConfigDict,
Field, Field,
TypeAdapter,
model_serializer,
model_validator,
) )
def serialize_msgpack(v: Any) -> Any: @dataclass
typ = v.__class__ class SignalRMeta:
if issubclass(typ, BaseModel): member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
return serialize_to_list(v) json_ignore: bool = False # implement of JsonIgnore (json) attribute
elif issubclass(typ, list): use_upper_case: bool = False # use upper CamelCase for field names
return TypeAdapter(
typ, config=ConfigDict(arbitrary_types_allowed=True)
).dump_python(v)
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
def serialize_to_list(value: BaseModel) -> list[Any]: class SignalRUnionMessage(BaseModel):
data = [] union_type: ClassVar[int]
for field, info in value.__class__.model_fields.items():
data.append(serialize_msgpack(v=getattr(value, field)))
return data
def _by_index(v: Any, class_: type[Enum]):
enum_list = list(class_)
if not isinstance(v, int):
return v
if 0 <= v < len(enum_list):
return enum_list[v]
raise ValueError(
f"Value {v} is out of range for enum "
f"{class_.__name__} with {len(enum_list)} items"
)
def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator:
return BeforeValidator(lambda v: _by_index(v, enum_class))
def msgpack_union(v):
data = v[1]
data.append(v[0])
return data
def msgpack_union_dump(v: BaseModel) -> list[Any]:
_type = getattr(v, "type", None)
if _type is None:
raise ValueError(
f"Model {v.__class__.__name__} does not have a '_type' attribute"
)
return [_type, serialize_to_list(v)]
class MessagePackArrayModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="before")
@classmethod
def unpack(cls, v: Any) -> Any:
if isinstance(v, list):
fields = list(cls.model_fields.keys())
if len(v) != len(fields):
raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
return dict(zip(fields, v))
return v
@model_serializer
def serialize(self) -> list[Any]:
return serialize_to_list(self)
class Transport(BaseModel): class Transport(BaseModel):

View File

@@ -5,14 +5,14 @@ from enum import IntEnum
from typing import Any from typing import Any
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from .score import ( from .score import (
ScoreStatisticsInt, ScoreStatistics,
) )
from .signalr import MessagePackArrayModel, UserState from .signalr import SignalRMeta, UserState
from msgpack_lazer_api import APIMod from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
class SpectatedUserState(IntEnum): class SpectatedUserState(IntEnum):
@@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum):
Quit = 5 Quit = 5
class SpectatorState(MessagePackArrayModel): class SpectatorState(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
beatmap_id: int | None = None beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3 ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list) mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState state: SpectatedUserState
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict) maximum_statistics: ScoreStatistics = Field(default_factory=dict)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState): if not isinstance(other, SpectatorState):
@@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel):
) )
class ScoreProcessorStatistics(MessagePackArrayModel): class ScoreProcessorStatistics(BaseModel):
base_score: int base_score: float
maximum_base_score: int maximum_base_score: float
accuracy_judgement_count: int accuracy_judgement_count: int
combo_portion: float combo_portion: float
bouns_portion: float bonus_portion: float
class FrameHeader(MessagePackArrayModel): class FrameHeader(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
total_score: int total_score: int
acc: float accuracy: float
combo: int combo: int
max_combo: int max_combo: int
statistics: ScoreStatisticsInt = Field(default_factory=dict) statistics: ScoreStatistics = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list) mods: list[APIMod] = Field(default_factory=list)
@@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel):
# SMOKE = 16 # SMOKE = 16
class LegacyReplayFrame(MessagePackArrayModel): class LegacyReplayFrame(BaseModel):
time: float # from ReplayFrame,the parent of LegacyReplayFrame time: float # from ReplayFrame,the parent of LegacyReplayFrame
x: float | None = None mouse_x: float | None = None
y: float | None = None mouse_y: float | None = None
button_state: int button_state: int
header: FrameHeader | None = Field(
default=None, metadata=[SignalRMeta(member_ignore=True)]
)
class FrameDataBundle(MessagePackArrayModel):
class FrameDataBundle(BaseModel):
header: FrameHeader header: FrameHeader
frames: list[LegacyReplayFrame] frames: list[LegacyReplayFrame]
@@ -106,18 +106,16 @@ class APIUser(BaseModel):
class ScoreInfo(BaseModel): class ScoreInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mods: list[APIMod] mods: list[APIMod]
user: APIUser user: APIUser
ruleset: int ruleset: int
maximum_statistics: ScoreStatisticsInt maximum_statistics: ScoreStatistics
id: int | None = None id: int | None = None
total_score: int | None = None total_score: int | None = None
acc: float | None = None accuracy: float | None = None
max_combo: int | None = None max_combo: int | None = None
combo: int | None = None combo: int | None = None
statistics: ScoreStatisticsInt = Field(default_factory=dict) statistics: ScoreStatistics = Field(default_factory=dict)
class StoreScore(BaseModel): class StoreScore(BaseModel):

View File

@@ -2,15 +2,13 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from enum import Enum
import inspect
import time import time
from typing import Any from typing import Any
from app.config import settings from app.config import settings
from app.exception import InvokeException from app.exception import InvokeException
from app.log import logger from app.log import logger
from app.models.signalr import UserState, _by_index from app.models.signalr import UserState
from app.signalr.packet import ( from app.signalr.packet import (
ClosePacket, ClosePacket,
CompletionPacket, CompletionPacket,
@@ -23,7 +21,6 @@ from app.signalr.store import ResultStore
from app.signalr.utils import get_signature from app.signalr.utils import get_signature
from fastapi import WebSocket from fastapi import WebSocket
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
@@ -51,7 +48,7 @@ class Client:
self.connection_id = connection_id self.connection_id = connection_id
self.connection_token = connection_token self.connection_token = connection_token
self.connection = connection self.connection = connection
self.procotol = protocol self.protocol = protocol
self._listen_task: asyncio.Task | None = None self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None self._ping_task: asyncio.Task | None = None
self._store = ResultStore() self._store = ResultStore()
@@ -64,14 +61,14 @@ class Client:
return int(self.connection_id) return int(self.connection_id)
async def send_packet(self, packet: Packet): async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet)) await self.connection.send_bytes(self.protocol.encode(packet))
async def receive_packets(self) -> list[Packet]: async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive() message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode() d = message.get("bytes") or message.get("text", "").encode()
if not d: if not d:
return [] return []
return self.procotol.decode(d) return self.protocol.decode(d)
async def _ping(self): async def _ping(self):
while True: while True:
@@ -265,14 +262,9 @@ class Hub[TState: UserState]:
for name, param in signature.parameters.items(): for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client: if name == "self" or param.annotation is Client:
continue continue
if issubclass(param.annotation, BaseModel): call_params.append(
call_params.append(param.annotation.model_validate(args.pop(0))) client.protocol.validate_object(args.pop(0), param.annotation)
elif inspect.isclass(param.annotation) and issubclass( )
param.annotation, Enum
):
call_params.append(_by_index(args.pop(0), param.annotation))
else:
call_params.append(args.pop(0))
return await method_(client, *call_params) return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any: async def call(self, client: Client, method: str, *args: Any) -> Any:

View File

@@ -11,7 +11,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv
from .hub import Client, Hub from .hub import Client, Hub
from pydantic import TypeAdapter
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -31,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
) -> set[Coroutine]: ) -> set[Coroutine]:
if store is not None and not store.pushable: if store is not None and not store.pushable:
return set() return set()
data = store.to_dict() if store else None data = store.for_push if store else None
return { return {
self.broadcast_group_call( self.broadcast_group_call(
self.online_presence_watchers_group(), self.online_presence_watchers_group(),
@@ -102,7 +101,9 @@ class MetadataHub(Hub[MetadataClientState]):
self.friend_presence_watchers_group(friend_id), self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated", "FriendPresenceUpdated",
friend_id, friend_id,
friend_state.to_dict(), friend_state.for_push
if friend_state.pushable
else None,
) )
) )
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@@ -122,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
client, client,
"UserPresenceUpdated", "UserPresenceUpdated",
user_id, user_id,
store.to_dict(), store.for_push,
) )
) )
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None: async def UpdateActivity(
self, client: Client, activity: UserActivity | None
) -> None:
user_id = int(client.connection_id) user_id = int(client.connection_id)
activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
store.user_activity = activity store.activity = activity
tasks = self.broadcast_tasks(user_id, store) tasks = self.broadcast_tasks(user_id, store)
tasks.add( tasks.add(
self.call_noblock( self.call_noblock(
client, client,
"UserPresenceUpdated", "UserPresenceUpdated",
user_id, user_id,
store.to_dict(), store.for_push,
) )
) )
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@@ -154,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
client, client,
"UserPresenceUpdated", "UserPresenceUpdated",
user_id, user_id,
store.to_dict(), store,
) )
for user_id, store in self.state.items() for user_id, store in self.state.items()
if store.pushable if store.pushable

View File

@@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken
from app.dependencies.database import engine from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.signalr import serialize_to_list
from app.models.spectator_hub import ( from app.models.spectator_hub import (
APIUser, APIUser,
FrameDataBundle, FrameDataBundle,
@@ -69,8 +68,8 @@ def save_replay(
md5: str, md5: str,
username: str, username: str,
score: Score, score: Score,
statistics: ScoreStatisticsInt, statistics: ScoreStatistics,
maximum_statistics: ScoreStatisticsInt, maximum_statistics: ScoreStatistics,
frames: list[LegacyReplayFrame], frames: list[LegacyReplayFrame],
) -> None: ) -> None:
data = bytearray() data = bytearray()
@@ -107,8 +106,8 @@ def save_replay(
last_time = 0 last_time = 0
for frame in frames: for frame in frames:
frame_strs.append( frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}" f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state}" f"|{frame.mouse_y or 0.0}|{frame.button_state}"
) )
last_time = frame.time last_time = frame.time
frame_strs.append("-12345|0|0|0") frame_strs.append("-12345|0|0|0")
@@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
async def on_client_connect(self, client: Client) -> None: async def on_client_connect(self, client: Client) -> None:
tasks = [ tasks = [
self.call_noblock( self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
for user_id, store in self.state.items() for user_id, store in self.state.items()
if store.state is not None if store.state is not None
] ]
@@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id), self.group_id(user_id),
"UserBeganPlaying", "UserBeganPlaying",
user_id, user_id,
serialize_to_list(state), state,
) )
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None: async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
@@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
state = self.get_or_create_state(client) state = self.get_or_create_state(client)
if not state.score: if not state.score:
return return
state.score.score_info.acc = frame_data.header.acc state.score.score_info.accuracy = frame_data.header.accuracy
state.score.score_info.combo = frame_data.header.combo state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics state.score.score_info.statistics = frame_data.header.statistics
@@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id), self.group_id(user_id),
"UserSentFrames", "UserSentFrames",
user_id, user_id,
frame_data.model_dump(), frame_data,
) )
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
@@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
self.group_id(user_id), self.group_id(user_id),
"UserFinishedPlaying", "UserFinishedPlaying",
user_id, user_id,
serialize_to_list(state) if state else None, state,
) )
async def StartWatchingUser(self, client: Client, target_id: int) -> None: async def StartWatchingUser(self, client: Client, target_id: int) -> None:
@@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
client, client,
"UserBeganPlaying", "UserBeganPlaying",
target_id, target_id,
serialize_to_list(target_store.state), target_store.state,
) )
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
store.watched_user.add(target_id) store.watched_user.add(target_id)

View File

@@ -1,16 +1,24 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum import datetime
from enum import Enum, IntEnum
import inspect
import json import json
from types import NoneType, UnionType
from typing import ( from typing import (
Any, Any,
Protocol as TypingProtocol, Protocol as TypingProtocol,
Union,
get_args,
get_origin,
) )
from app.models.signalr import serialize_msgpack from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel
import msgpack_lazer_api as m import msgpack_lazer_api as m
from pydantic import BaseModel
SEP = b"\x1e" SEP = b"\x1e"
@@ -75,8 +83,61 @@ class Protocol(TypingProtocol):
@staticmethod @staticmethod
def encode(packet: Packet) -> bytes: ... def encode(packet: Packet) -> bytes: ...
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any: ...
class MsgpackProtocol: class MsgpackProtocol:
@classmethod
def serialize_msgpack(cls, v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_to_list(v)
elif issubclass(typ, list):
return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif isinstance(v, dict):
return {
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
for k, value in v.items()
}
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
@classmethod
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
values = []
for field, info in value.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.member_ignore:
continue
values.append(cls.serialize_msgpack(v=getattr(value, field)))
if issubclass(value.__class__, SignalRUnionMessage):
return [value.__class__.union_type, values]
else:
return values
@staticmethod
def process_object(v: Any, typ: type[BaseModel]) -> Any:
if isinstance(v, list):
d = {}
for i, f in enumerate(typ.model_fields.items()):
field, info = f
if info.exclude:
continue
anno = info.annotation
if anno is None:
d[camel_to_snake(field)] = v[i]
continue
d[field] = MsgpackProtocol.validate_object(v[i], anno)
return d
return v
@staticmethod @staticmethod
def _encode_varint(value: int) -> bytes: def _encode_varint(value: int) -> bytes:
result = [] result = []
@@ -142,6 +203,49 @@ class MsgpackProtocol:
] ]
raise ValueError(f"Unsupported packet type: {packet_type}") raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
union_type = v[0]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.union_type == union_type:
return cls.validate_object(v[1], arg)
return v
@staticmethod @staticmethod
def encode(packet: Packet) -> bytes: def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}] payload = [packet.type.value, packet.header or {}]
@@ -153,7 +257,9 @@ class MsgpackProtocol:
] ]
) )
if packet.arguments is not None: if packet.arguments is not None:
payload.append([serialize_msgpack(arg) for arg in packet.arguments]) payload.append(
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
)
if packet.stream_ids is not None: if packet.stream_ids is not None:
payload.append(packet.stream_ids) payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket): elif isinstance(packet, CompletionPacket):
@@ -166,7 +272,9 @@ class MsgpackProtocol:
[ [
packet.invocation_id, packet.invocation_id,
result_kind, result_kind,
packet.error or packet.result or None, packet.error
or MsgpackProtocol.serialize_msgpack(packet.result)
or None,
] ]
) )
elif isinstance(packet, ClosePacket): elif isinstance(packet, ClosePacket):
@@ -183,6 +291,62 @@ class MsgpackProtocol:
class JSONProtocol: class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v)
elif isinstance(v, dict):
return {
cls.serialize_to_json(k): cls.serialize_to_json(value)
for k, value in v.items()
}
elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime):
return v.isoformat()
elif isinstance(v, Enum):
return v.value
return v
@classmethod
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
d = {}
for field, info in v.__class__.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
cls.serialize_to_json(getattr(v, field))
)
if issubclass(v.__class__, SignalRUnionMessage):
return {
"$dtype": v.__class__.__name__,
"$value": d,
}
return d
@staticmethod
def process_object(
v: Any, typ: type[BaseModel], from_union: bool = False
) -> dict[str, Any]:
d = {}
for field, info in typ.model_fields.items():
metadata = next(
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
)
if metadata and metadata.json_ignore:
continue
value = v.get(snake_to_camel(field, not from_union))
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
continue
d[field] = JSONProtocol.validate_object(value, anno)
return d
@staticmethod @staticmethod
def decode(input: bytes) -> list[Packet]: def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP) packets_raw = input.removesuffix(SEP).split(SEP)
@@ -227,6 +391,52 @@ class JSONProtocol:
] ]
raise ValueError(f"Unsupported packet type: {packet_type}") raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v)
elif isinstance(v, list):
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
v, get_args(typ)[1]
)
for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
):
raise ValueError(
f"Cannot validate {v} to {typ}, "
"only SignalRUnionMessage subclasses are supported"
)
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
union_type = v["$dtype"]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.__name__ == union_type:
return cls.validate_object(v["$value"], arg, True)
return v
@staticmethod @staticmethod
def encode(packet: Packet) -> bytes: def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = { payload: dict[str, Any] = {
@@ -243,7 +453,9 @@ class JSONProtocol:
if packet.invocation_id is not None: if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id payload["invocationId"] = packet.invocation_id
if packet.arguments is not None: if packet.arguments is not None:
payload["arguments"] = packet.arguments payload["arguments"] = [
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
]
if packet.stream_ids is not None: if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket): elif isinstance(packet, CompletionPacket):
@@ -255,7 +467,7 @@ class JSONProtocol:
if packet.error is not None: if packet.error is not None:
payload["error"] = packet.error payload["error"] = packet.error
if packet.result is not None: if packet.result is not None:
payload["result"] = packet.result payload["result"] = JSONProtocol.serialize_to_json(packet.result)
elif isinstance(packet, PingPacket): elif isinstance(packet, PingPacket):
pass pass
elif isinstance(packet, ClosePacket): elif isinstance(packet, ClosePacket):

View File

@@ -4,3 +4,55 @@ from __future__ import annotations
def unix_timestamp_to_windows(timestamp: int) -> int: def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp.""" """Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000 return (timestamp + 62135596800) * 10_000_000
def camel_to_snake(name: str) -> str:
"""Convert a camelCase string to snake_case."""
result = []
last_chr = ""
for char in name:
if char.isupper():
if not last_chr.isupper() and result:
result.append("_")
result.append(char.lower())
else:
result.append(char)
last_chr = char
return "".join(result)
def snake_to_camel(name: str, lower_case: bool = True) -> str:
"""Convert a snake_case string to camelCase."""
if not name:
return name
parts = name.split("_")
if not parts:
return name
# 常见缩写词列表
abbreviations = {
"id",
"url",
"api",
"http",
"https",
"xml",
"json",
"css",
"html",
"sql",
"db",
}
result = []
for part in parts:
if part.lower() in abbreviations:
result.append(part.upper())
else:
if result or not lower_case:
result.append(part.capitalize())
else:
result.append(part.lower())
return "".join(result)

View File

@@ -1,11 +1,4 @@
from typing import Any from typing import Any
class APIMod:
def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ...
@property
def acronym(self) -> str: ...
@property
def settings(self) -> dict[str, Any]: ...
def encode(obj: Any) -> bytes: ... def encode(obj: Any) -> bytes: ...
def decode(data: bytes) -> Any: ... def decode(data: bytes) -> Any: ...

View File

@@ -1,8 +1,6 @@
use crate::APIMod;
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use pyo3::types::PyDict; use pyo3::types::PyDict;
use pyo3::{prelude::*, IntoPyObjectExt}; use pyo3::{prelude::*, IntoPyObjectExt};
use std::collections::HashMap;
use std::io::Read; use std::io::Read;
pub fn read_object( pub fn read_object(
@@ -206,13 +204,12 @@ fn read_array(
let obj1 = read_object(py, cursor, false)?; let obj1 = read_object(py, cursor, false)?;
if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) { if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) {
let obj2 = read_object(py, cursor, true)?; let obj2 = read_object(py, cursor, true)?;
return Ok(APIMod {
acronym: obj1.extract::<String>(py)?, let api_mod_dict = PyDict::new(py);
settings: obj2.extract::<HashMap<String, PyObject>>(py)?, api_mod_dict.set_item("acronym", obj1)?;
} api_mod_dict.set_item("settings", obj2)?;
.into_pyobject(py)?
.into_any() return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind());
.unbind());
} else { } else {
items.push(obj1); items.push(obj1);
i += 1; i += 1;

View File

@@ -1,8 +1,7 @@
use crate::APIMod; use chrono::{DateTime, Utc};
use chrono::{DateTime, Utc}; use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods};
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods};
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString}; use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
use pyo3::{Bound, PyAny, PyRef, Python}; use pyo3::{Bound, PyAny};
use std::io::Write; use std::io::Write;
fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) { fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) {
@@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec<u8>, obj: &Bound<'_, PyDict>) {
} }
} }
fn write_nil(buf: &mut Vec<u8>){ fn write_nil(buf: &mut Vec<u8>) {
rmp::encode::write_nil(buf).unwrap(); rmp::encode::write_nil(buf).unwrap();
} }
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool {
fn write_api_mod(buf: &mut Vec<u8>, api_mod: PyRef<APIMod>) { if let Ok(Some(acronym)) = dict.get_item("acronym") {
rmp::encode::write_array_len(buf, 2).unwrap(); if let Ok(acronym_str) = acronym.extract::<String>() {
rmp::encode::write_str(buf, &api_mod.acronym).unwrap(); return acronym_str.len() == 2;
rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap(); }
for (k, v) in api_mod.settings.iter() {
rmp::encode::write_str(buf, k).unwrap();
Python::with_gil(|py| write_object(buf, &v.bind(py)));
} }
false
}
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
fn write_api_mod(buf: &mut Vec<u8>, api_mod: &Bound<'_, PyDict>) -> PyResult<()> {
let acronym = api_mod
.get_item("acronym")?
.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?;
let acronym_str = acronym.extract::<String>()?;
let settings = api_mod
.get_item("settings")?
.unwrap_or_else(|| PyDict::new(acronym.py()).into_any());
let settings_dict = settings.downcast::<PyDict>()?;
rmp::encode::write_array_len(buf, 2).unwrap();
rmp::encode::write_str(buf, &acronym_str).unwrap();
rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap();
for (k, v) in settings_dict.iter() {
let key_str = k.extract::<String>()?;
rmp::encode::write_str(buf, &key_str).unwrap();
write_object(buf, &v);
}
Ok(())
} }
fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) { fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) {
@@ -111,21 +133,23 @@ pub fn write_object(buf: &mut Vec<u8>, obj: &Bound<'_, PyAny>) {
} else if let Ok(string) = obj.downcast::<PyString>() { } else if let Ok(string) = obj.downcast::<PyString>() {
write_string(buf, string); write_string(buf, string);
} else if let Ok(boolean) = obj.downcast::<PyBool>() { } else if let Ok(boolean) = obj.downcast::<PyBool>() {
write_bool(buf, boolean); write_bool(buf, boolean);
} else if let Ok(float) = obj.downcast::<PyFloat>() { } else if let Ok(float) = obj.downcast::<PyFloat>() {
write_float(buf, float); write_float(buf, float);
} else if let Ok(integer) = obj.downcast::<PyInt>() { } else if let Ok(integer) = obj.downcast::<PyInt>() {
write_integer(buf, integer); write_integer(buf, integer);
} else if let Ok(bytes) = obj.downcast::<PyBytes>() { } else if let Ok(bytes) = obj.downcast::<PyBytes>() {
write_bin(buf, bytes); write_bin(buf, bytes);
} else if let Ok(dict) = obj.downcast::<PyDict>() { } else if let Ok(dict) = obj.downcast::<PyDict>() {
write_hashmap(buf, dict); if is_api_mod(dict) {
write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict));
} else {
write_hashmap(buf, dict);
}
} else if let Ok(_none) = obj.downcast::<PyNone>() { } else if let Ok(_none) = obj.downcast::<PyNone>() {
write_nil(buf); write_nil(buf);
} else if let Ok(datetime) = obj.downcast::<PyDateTime>() { } else if let Ok(datetime) = obj.downcast::<PyDateTime>() {
write_datetime(buf, datetime); write_datetime(buf, datetime);
} else if let Ok(api_mod) = obj.extract::<PyRef<APIMod>>() {
write_api_mod(buf, api_mod);
} else { } else {
panic!("Unsupported type"); panic!("Unsupported type");
} }

View File

@@ -2,30 +2,6 @@ mod decode;
mod encode; mod encode;
use pyo3::prelude::*; use pyo3::prelude::*;
use std::collections::HashMap;
#[pyclass]
struct APIMod {
#[pyo3(get, set)]
acronym: String,
#[pyo3(get, set)]
settings: HashMap<String, PyObject>,
}
#[pymethods]
impl APIMod {
#[new]
fn new(acronym: String, settings: HashMap<String, PyObject>) -> Self {
APIMod { acronym, settings }
}
fn __repr__(&self) -> String {
format!(
"APIMod(acronym='{}', settings={:?})",
self.acronym, self.settings
)
}
}
#[pyfunction] #[pyfunction]
#[pyo3(name = "encode")] #[pyo3(name = "encode")]
@@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult<PyObject> {
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> { fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_py, m)?)?; m.add_function(wrap_pyfunction!(encode_py, m)?)?;
m.add_function(wrap_pyfunction!(decode_py, m)?)?; m.add_function(wrap_pyfunction!(decode_py, m)?)?;
m.add_class::<APIMod>()?;
Ok(()) Ok(())
} }