chore(merge): merge branch 'main' into feat/multiplayer-api
This commit is contained in:
@@ -1,114 +1,85 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Literal
|
from typing import ClassVar, Literal
|
||||||
|
|
||||||
from app.models.signalr import UserState
|
from app.models.signalr import SignalRMeta, SignalRUnionMessage, UserState
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class _UserActivity(BaseModel):
|
class _UserActivity(SignalRUnionMessage): ...
|
||||||
model_config = ConfigDict(serialize_by_alias=True)
|
|
||||||
type: Literal[
|
|
||||||
"ChoosingBeatmap",
|
|
||||||
"InSoloGame",
|
|
||||||
"WatchingReplay",
|
|
||||||
"SpectatingUser",
|
|
||||||
"SearchingForLobby",
|
|
||||||
"InLobby",
|
|
||||||
"InMultiplayerGame",
|
|
||||||
"SpectatingMultiplayerGame",
|
|
||||||
"InPlaylistGame",
|
|
||||||
"EditingBeatmap",
|
|
||||||
"ModdingBeatmap",
|
|
||||||
"TestingBeatmap",
|
|
||||||
"InDailyChallengeLobby",
|
|
||||||
"PlayingDailyChallenge",
|
|
||||||
] = Field(alias="$dtype")
|
|
||||||
value: Any | None = Field(alias="$value")
|
|
||||||
|
|
||||||
|
|
||||||
class ChoosingBeatmap(_UserActivity):
|
class ChoosingBeatmap(_UserActivity):
|
||||||
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[11]] = 11
|
||||||
|
|
||||||
|
|
||||||
class InGameValue(BaseModel):
|
|
||||||
beatmap_id: int = Field(alias="BeatmapID")
|
|
||||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
|
||||||
ruleset_id: int = Field(alias="RulesetID")
|
|
||||||
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
|
|
||||||
|
|
||||||
|
|
||||||
class _InGame(_UserActivity):
|
class _InGame(_UserActivity):
|
||||||
value: InGameValue = Field(alias="$value")
|
beatmap_id: int
|
||||||
|
beatmap_display_title: str
|
||||||
|
ruleset_id: int
|
||||||
|
ruleset_playing_verb: str
|
||||||
|
|
||||||
|
|
||||||
class InSoloGame(_InGame):
|
class InSoloGame(_InGame):
|
||||||
type: Literal["InSoloGame"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[12]] = 12
|
||||||
|
|
||||||
|
|
||||||
class InMultiplayerGame(_InGame):
|
class InMultiplayerGame(_InGame):
|
||||||
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[23]] = 23
|
||||||
|
|
||||||
|
|
||||||
class SpectatingMultiplayerGame(_InGame):
|
class SpectatingMultiplayerGame(_InGame):
|
||||||
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[24]] = 24
|
||||||
|
|
||||||
|
|
||||||
class InPlaylistGame(_InGame):
|
class InPlaylistGame(_InGame):
|
||||||
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[31]] = 31
|
||||||
|
|
||||||
|
|
||||||
class EditingBeatmapValue(BaseModel):
|
class PlayingDailyChallenge(_InGame):
|
||||||
beatmap_id: int = Field(alias="BeatmapID")
|
union_type: ClassVar[Literal[52]] = 52
|
||||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
|
||||||
|
|
||||||
|
|
||||||
class EditingBeatmap(_UserActivity):
|
class EditingBeatmap(_UserActivity):
|
||||||
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[41]] = 41
|
||||||
value: EditingBeatmapValue = Field(alias="$value")
|
beatmap_id: int
|
||||||
|
beatmap_display_title: str
|
||||||
|
|
||||||
|
|
||||||
class TestingBeatmap(_UserActivity):
|
class TestingBeatmap(EditingBeatmap):
|
||||||
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[43]] = 43
|
||||||
|
|
||||||
|
|
||||||
class ModdingBeatmap(_UserActivity):
|
class ModdingBeatmap(EditingBeatmap):
|
||||||
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[42]] = 42
|
||||||
|
|
||||||
|
|
||||||
class WatchingReplayValue(BaseModel):
|
|
||||||
score_id: int = Field(alias="ScoreID")
|
|
||||||
player_name: str = Field(alias="PlayerName")
|
|
||||||
beatmap_id: int = Field(alias="BeatmapID")
|
|
||||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
|
||||||
|
|
||||||
|
|
||||||
class WatchingReplay(_UserActivity):
|
class WatchingReplay(_UserActivity):
|
||||||
type: Literal["WatchingReplay"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[13]] = 13
|
||||||
value: int | None = Field(alias="$value") # Replay ID
|
score_id: int
|
||||||
|
player_name: str
|
||||||
|
beatmap_id: int
|
||||||
|
beatmap_display_title: str
|
||||||
|
|
||||||
|
|
||||||
class SpectatingUser(WatchingReplay):
|
class SpectatingUser(WatchingReplay):
|
||||||
type: Literal["SpectatingUser"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[14]] = 14
|
||||||
|
|
||||||
|
|
||||||
class SearchingForLobby(_UserActivity):
|
class SearchingForLobby(_UserActivity):
|
||||||
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[21]] = 21
|
||||||
|
|
||||||
|
|
||||||
class InLobbyValue(BaseModel):
|
|
||||||
room_id: int = Field(alias="RoomID")
|
|
||||||
room_name: str = Field(alias="RoomName")
|
|
||||||
|
|
||||||
|
|
||||||
class InLobby(_UserActivity):
|
class InLobby(_UserActivity):
|
||||||
type: Literal["InLobby"] = "InLobby"
|
union_type: ClassVar[Literal[22]] = 22
|
||||||
|
room_id: int
|
||||||
|
room_name: str
|
||||||
|
|
||||||
|
|
||||||
class InDailyChallengeLobby(_UserActivity):
|
class InDailyChallengeLobby(_UserActivity):
|
||||||
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
|
union_type: ClassVar[Literal[51]] = 51
|
||||||
|
|
||||||
|
|
||||||
UserActivity = (
|
UserActivity = (
|
||||||
@@ -128,23 +99,28 @@ UserActivity = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataClientState(UserState):
|
class UserPresence(BaseModel):
|
||||||
user_activity: UserActivity | None = None
|
activity: UserActivity | None = Field(
|
||||||
status: OnlineStatus | None = None
|
default=None, metadata=SignalRMeta(use_upper_case=True)
|
||||||
|
)
|
||||||
def to_dict(self) -> dict[str, Any] | None:
|
status: OnlineStatus | None = Field(
|
||||||
if self.status is None or self.status == OnlineStatus.OFFLINE:
|
default=None, metadata=SignalRMeta(use_upper_case=True)
|
||||||
return None
|
)
|
||||||
dumped = self.model_dump(by_alias=True, exclude_none=True)
|
|
||||||
return {
|
|
||||||
"Activity": dumped.get("user_activity"),
|
|
||||||
"Status": dumped.get("status"),
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pushable(self) -> bool:
|
def pushable(self) -> bool:
|
||||||
return self.status is not None and self.status != OnlineStatus.OFFLINE
|
return self.status is not None and self.status != OnlineStatus.OFFLINE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def for_push(self) -> "UserPresence | None":
|
||||||
|
return UserPresence(
|
||||||
|
activity=self.activity,
|
||||||
|
status=self.status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataClientState(UserPresence, UserState): ...
|
||||||
|
|
||||||
|
|
||||||
class OnlineStatus(IntEnum):
|
class OnlineStatus(IntEnum):
|
||||||
OFFLINE = 0 # 隐身
|
OFFLINE = 0 # 隐身
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
from .mods import API_MODS, APIMod, init_mods
|
from .mods import API_MODS, APIMod, init_mods
|
||||||
@@ -93,43 +93,6 @@ class HitResult(str, Enum):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HitResultInt(IntEnum):
|
|
||||||
PERFECT = 0
|
|
||||||
GREAT = 1
|
|
||||||
GOOD = 2
|
|
||||||
OK = 3
|
|
||||||
MEH = 4
|
|
||||||
MISS = 5
|
|
||||||
|
|
||||||
LARGE_TICK_HIT = 6
|
|
||||||
SMALL_TICK_HIT = 7
|
|
||||||
SLIDER_TAIL_HIT = 8
|
|
||||||
|
|
||||||
LARGE_BONUS = 9
|
|
||||||
SMALL_BONUS = 10
|
|
||||||
|
|
||||||
LARGE_TICK_MISS = 11
|
|
||||||
SMALL_TICK_MISS = 12
|
|
||||||
|
|
||||||
IGNORE_HIT = 13
|
|
||||||
IGNORE_MISS = 14
|
|
||||||
|
|
||||||
NONE = 15
|
|
||||||
COMBO_BREAK = 16
|
|
||||||
|
|
||||||
LEGACY_COMBO_INCREASE = 99
|
|
||||||
|
|
||||||
def is_hit(self) -> bool:
|
|
||||||
return self not in (
|
|
||||||
HitResultInt.NONE,
|
|
||||||
HitResultInt.IGNORE_MISS,
|
|
||||||
HitResultInt.COMBO_BREAK,
|
|
||||||
HitResultInt.LARGE_TICK_MISS,
|
|
||||||
HitResultInt.SMALL_TICK_MISS,
|
|
||||||
HitResultInt.MISS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LeaderboardType(Enum):
|
class LeaderboardType(Enum):
|
||||||
GLOBAL = "global"
|
GLOBAL = "global"
|
||||||
FRIENDS = "friend"
|
FRIENDS = "friend"
|
||||||
@@ -138,7 +101,6 @@ class LeaderboardType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
ScoreStatistics = dict[HitResult, int]
|
ScoreStatistics = dict[HitResult, int]
|
||||||
ScoreStatisticsInt = dict[HitResultInt, int]
|
|
||||||
|
|
||||||
|
|
||||||
class SoloScoreSubmissionInfo(BaseModel):
|
class SoloScoreSubmissionInfo(BaseModel):
|
||||||
@@ -176,8 +138,8 @@ class SoloScoreSubmissionInfo(BaseModel):
|
|||||||
class LegacyReplaySoloScoreInfo(TypedDict):
|
class LegacyReplaySoloScoreInfo(TypedDict):
|
||||||
online_id: int
|
online_id: int
|
||||||
mods: list[APIMod]
|
mods: list[APIMod]
|
||||||
statistics: ScoreStatisticsInt
|
statistics: ScoreStatistics
|
||||||
maximum_statistics: ScoreStatisticsInt
|
maximum_statistics: ScoreStatistics
|
||||||
client_version: str
|
client_version: str
|
||||||
rank: Rank
|
rank: Rank
|
||||||
user_id: int
|
user_id: int
|
||||||
|
|||||||
@@ -1,90 +1,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from typing import ClassVar
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
BeforeValidator,
|
|
||||||
ConfigDict,
|
|
||||||
Field,
|
Field,
|
||||||
TypeAdapter,
|
|
||||||
model_serializer,
|
|
||||||
model_validator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def serialize_msgpack(v: Any) -> Any:
|
@dataclass
|
||||||
typ = v.__class__
|
class SignalRMeta:
|
||||||
if issubclass(typ, BaseModel):
|
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||||
return serialize_to_list(v)
|
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||||
elif issubclass(typ, list):
|
use_upper_case: bool = False # use upper CamelCase for field names
|
||||||
return TypeAdapter(
|
|
||||||
typ, config=ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
).dump_python(v)
|
|
||||||
elif issubclass(typ, datetime.datetime):
|
|
||||||
return [v, 0]
|
|
||||||
elif issubclass(typ, Enum):
|
|
||||||
list_ = list(typ)
|
|
||||||
return list_.index(v) if v in list_ else v.value
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_to_list(value: BaseModel) -> list[Any]:
|
class SignalRUnionMessage(BaseModel):
|
||||||
data = []
|
union_type: ClassVar[int]
|
||||||
for field, info in value.__class__.model_fields.items():
|
|
||||||
data.append(serialize_msgpack(v=getattr(value, field)))
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def _by_index(v: Any, class_: type[Enum]):
|
|
||||||
enum_list = list(class_)
|
|
||||||
if not isinstance(v, int):
|
|
||||||
return v
|
|
||||||
if 0 <= v < len(enum_list):
|
|
||||||
return enum_list[v]
|
|
||||||
raise ValueError(
|
|
||||||
f"Value {v} is out of range for enum "
|
|
||||||
f"{class_.__name__} with {len(enum_list)} items"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def EnumByIndex(enum_class: type[Enum]) -> BeforeValidator:
|
|
||||||
return BeforeValidator(lambda v: _by_index(v, enum_class))
|
|
||||||
|
|
||||||
|
|
||||||
def msgpack_union(v):
|
|
||||||
data = v[1]
|
|
||||||
data.append(v[0])
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def msgpack_union_dump(v: BaseModel) -> list[Any]:
|
|
||||||
_type = getattr(v, "type", None)
|
|
||||||
if _type is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Model {v.__class__.__name__} does not have a '_type' attribute"
|
|
||||||
)
|
|
||||||
return [_type, serialize_to_list(v)]
|
|
||||||
|
|
||||||
|
|
||||||
class MessagePackArrayModel(BaseModel):
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def unpack(cls, v: Any) -> Any:
|
|
||||||
if isinstance(v, list):
|
|
||||||
fields = list(cls.model_fields.keys())
|
|
||||||
if len(v) != len(fields):
|
|
||||||
raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
|
|
||||||
return dict(zip(fields, v))
|
|
||||||
return v
|
|
||||||
|
|
||||||
@model_serializer
|
|
||||||
def serialize(self) -> list[Any]:
|
|
||||||
return serialize_to_list(self)
|
|
||||||
|
|
||||||
|
|
||||||
class Transport(BaseModel):
|
class Transport(BaseModel):
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ from enum import IntEnum
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.models.beatmap import BeatmapRankStatus
|
from app.models.beatmap import BeatmapRankStatus
|
||||||
|
from app.models.mods import APIMod
|
||||||
|
|
||||||
from .score import (
|
from .score import (
|
||||||
ScoreStatisticsInt,
|
ScoreStatistics,
|
||||||
)
|
)
|
||||||
from .signalr import MessagePackArrayModel, UserState
|
from .signalr import SignalRMeta, UserState
|
||||||
|
|
||||||
from msgpack_lazer_api import APIMod
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class SpectatedUserState(IntEnum):
|
class SpectatedUserState(IntEnum):
|
||||||
@@ -24,14 +24,12 @@ class SpectatedUserState(IntEnum):
|
|||||||
Quit = 5
|
Quit = 5
|
||||||
|
|
||||||
|
|
||||||
class SpectatorState(MessagePackArrayModel):
|
class SpectatorState(BaseModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
beatmap_id: int | None = None
|
beatmap_id: int | None = None
|
||||||
ruleset_id: int | None = None # 0,1,2,3
|
ruleset_id: int | None = None # 0,1,2,3
|
||||||
mods: list[APIMod] = Field(default_factory=list)
|
mods: list[APIMod] = Field(default_factory=list)
|
||||||
state: SpectatedUserState
|
state: SpectatedUserState
|
||||||
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, SpectatorState):
|
if not isinstance(other, SpectatorState):
|
||||||
@@ -44,22 +42,20 @@ class SpectatorState(MessagePackArrayModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ScoreProcessorStatistics(MessagePackArrayModel):
|
class ScoreProcessorStatistics(BaseModel):
|
||||||
base_score: int
|
base_score: float
|
||||||
maximum_base_score: int
|
maximum_base_score: float
|
||||||
accuracy_judgement_count: int
|
accuracy_judgement_count: int
|
||||||
combo_portion: float
|
combo_portion: float
|
||||||
bouns_portion: float
|
bonus_portion: float
|
||||||
|
|
||||||
|
|
||||||
class FrameHeader(MessagePackArrayModel):
|
class FrameHeader(BaseModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
total_score: int
|
total_score: int
|
||||||
acc: float
|
accuracy: float
|
||||||
combo: int
|
combo: int
|
||||||
max_combo: int
|
max_combo: int
|
||||||
statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||||
score_processor_statistics: ScoreProcessorStatistics
|
score_processor_statistics: ScoreProcessorStatistics
|
||||||
received_time: datetime.datetime
|
received_time: datetime.datetime
|
||||||
mods: list[APIMod] = Field(default_factory=list)
|
mods: list[APIMod] = Field(default_factory=list)
|
||||||
@@ -87,14 +83,18 @@ class FrameHeader(MessagePackArrayModel):
|
|||||||
# SMOKE = 16
|
# SMOKE = 16
|
||||||
|
|
||||||
|
|
||||||
class LegacyReplayFrame(MessagePackArrayModel):
|
class LegacyReplayFrame(BaseModel):
|
||||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||||
x: float | None = None
|
mouse_x: float | None = None
|
||||||
y: float | None = None
|
mouse_y: float | None = None
|
||||||
button_state: int
|
button_state: int
|
||||||
|
|
||||||
|
header: FrameHeader | None = Field(
|
||||||
|
default=None, metadata=[SignalRMeta(member_ignore=True)]
|
||||||
|
)
|
||||||
|
|
||||||
class FrameDataBundle(MessagePackArrayModel):
|
|
||||||
|
class FrameDataBundle(BaseModel):
|
||||||
header: FrameHeader
|
header: FrameHeader
|
||||||
frames: list[LegacyReplayFrame]
|
frames: list[LegacyReplayFrame]
|
||||||
|
|
||||||
@@ -106,18 +106,16 @@ class APIUser(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ScoreInfo(BaseModel):
|
class ScoreInfo(BaseModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
mods: list[APIMod]
|
mods: list[APIMod]
|
||||||
user: APIUser
|
user: APIUser
|
||||||
ruleset: int
|
ruleset: int
|
||||||
maximum_statistics: ScoreStatisticsInt
|
maximum_statistics: ScoreStatistics
|
||||||
id: int | None = None
|
id: int | None = None
|
||||||
total_score: int | None = None
|
total_score: int | None = None
|
||||||
acc: float | None = None
|
accuracy: float | None = None
|
||||||
max_combo: int | None = None
|
max_combo: int | None = None
|
||||||
combo: int | None = None
|
combo: int | None = None
|
||||||
statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class StoreScore(BaseModel):
|
class StoreScore(BaseModel):
|
||||||
|
|||||||
@@ -2,15 +2,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from enum import Enum
|
|
||||||
import inspect
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.exception import InvokeException
|
from app.exception import InvokeException
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.models.signalr import UserState, _by_index
|
from app.models.signalr import UserState
|
||||||
from app.signalr.packet import (
|
from app.signalr.packet import (
|
||||||
ClosePacket,
|
ClosePacket,
|
||||||
CompletionPacket,
|
CompletionPacket,
|
||||||
@@ -23,7 +21,6 @@ from app.signalr.store import ResultStore
|
|||||||
from app.signalr.utils import get_signature
|
from app.signalr.utils import get_signature
|
||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from pydantic import BaseModel
|
|
||||||
from starlette.websockets import WebSocketDisconnect
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
|
||||||
|
|
||||||
@@ -51,7 +48,7 @@ class Client:
|
|||||||
self.connection_id = connection_id
|
self.connection_id = connection_id
|
||||||
self.connection_token = connection_token
|
self.connection_token = connection_token
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.procotol = protocol
|
self.protocol = protocol
|
||||||
self._listen_task: asyncio.Task | None = None
|
self._listen_task: asyncio.Task | None = None
|
||||||
self._ping_task: asyncio.Task | None = None
|
self._ping_task: asyncio.Task | None = None
|
||||||
self._store = ResultStore()
|
self._store = ResultStore()
|
||||||
@@ -64,14 +61,14 @@ class Client:
|
|||||||
return int(self.connection_id)
|
return int(self.connection_id)
|
||||||
|
|
||||||
async def send_packet(self, packet: Packet):
|
async def send_packet(self, packet: Packet):
|
||||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
await self.connection.send_bytes(self.protocol.encode(packet))
|
||||||
|
|
||||||
async def receive_packets(self) -> list[Packet]:
|
async def receive_packets(self) -> list[Packet]:
|
||||||
message = await self.connection.receive()
|
message = await self.connection.receive()
|
||||||
d = message.get("bytes") or message.get("text", "").encode()
|
d = message.get("bytes") or message.get("text", "").encode()
|
||||||
if not d:
|
if not d:
|
||||||
return []
|
return []
|
||||||
return self.procotol.decode(d)
|
return self.protocol.decode(d)
|
||||||
|
|
||||||
async def _ping(self):
|
async def _ping(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -265,14 +262,9 @@ class Hub[TState: UserState]:
|
|||||||
for name, param in signature.parameters.items():
|
for name, param in signature.parameters.items():
|
||||||
if name == "self" or param.annotation is Client:
|
if name == "self" or param.annotation is Client:
|
||||||
continue
|
continue
|
||||||
if issubclass(param.annotation, BaseModel):
|
call_params.append(
|
||||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
client.protocol.validate_object(args.pop(0), param.annotation)
|
||||||
elif inspect.isclass(param.annotation) and issubclass(
|
)
|
||||||
param.annotation, Enum
|
|
||||||
):
|
|
||||||
call_params.append(_by_index(args.pop(0), param.annotation))
|
|
||||||
else:
|
|
||||||
call_params.append(args.pop(0))
|
|
||||||
return await method_(client, *call_params)
|
return await method_(client, *call_params)
|
||||||
|
|
||||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActiv
|
|||||||
|
|
||||||
from .hub import Client, Hub
|
from .hub import Client, Hub
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -31,7 +30,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
) -> set[Coroutine]:
|
) -> set[Coroutine]:
|
||||||
if store is not None and not store.pushable:
|
if store is not None and not store.pushable:
|
||||||
return set()
|
return set()
|
||||||
data = store.to_dict() if store else None
|
data = store.for_push if store else None
|
||||||
return {
|
return {
|
||||||
self.broadcast_group_call(
|
self.broadcast_group_call(
|
||||||
self.online_presence_watchers_group(),
|
self.online_presence_watchers_group(),
|
||||||
@@ -102,7 +101,9 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
self.friend_presence_watchers_group(friend_id),
|
self.friend_presence_watchers_group(friend_id),
|
||||||
"FriendPresenceUpdated",
|
"FriendPresenceUpdated",
|
||||||
friend_id,
|
friend_id,
|
||||||
friend_state.to_dict(),
|
friend_state.for_push
|
||||||
|
if friend_state.pushable
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@@ -122,27 +123,24 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
client,
|
client,
|
||||||
"UserPresenceUpdated",
|
"UserPresenceUpdated",
|
||||||
user_id,
|
user_id,
|
||||||
store.to_dict(),
|
store.for_push,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
async def UpdateActivity(
|
||||||
|
self, client: Client, activity: UserActivity | None
|
||||||
|
) -> None:
|
||||||
user_id = int(client.connection_id)
|
user_id = int(client.connection_id)
|
||||||
activity = (
|
|
||||||
TypeAdapter(UserActivity).validate_python(activity_dict)
|
|
||||||
if activity_dict
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
store = self.get_or_create_state(client)
|
store = self.get_or_create_state(client)
|
||||||
store.user_activity = activity
|
store.activity = activity
|
||||||
tasks = self.broadcast_tasks(user_id, store)
|
tasks = self.broadcast_tasks(user_id, store)
|
||||||
tasks.add(
|
tasks.add(
|
||||||
self.call_noblock(
|
self.call_noblock(
|
||||||
client,
|
client,
|
||||||
"UserPresenceUpdated",
|
"UserPresenceUpdated",
|
||||||
user_id,
|
user_id,
|
||||||
store.to_dict(),
|
store.for_push,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
@@ -154,7 +152,7 @@ class MetadataHub(Hub[MetadataClientState]):
|
|||||||
client,
|
client,
|
||||||
"UserPresenceUpdated",
|
"UserPresenceUpdated",
|
||||||
user_id,
|
user_id,
|
||||||
store.to_dict(),
|
store,
|
||||||
)
|
)
|
||||||
for user_id, store in self.state.items()
|
for user_id, store in self.state.items()
|
||||||
if store.pushable
|
if store.pushable
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ from app.database.score_token import ScoreToken
|
|||||||
from app.dependencies.database import engine
|
from app.dependencies.database import engine
|
||||||
from app.models.beatmap import BeatmapRankStatus
|
from app.models.beatmap import BeatmapRankStatus
|
||||||
from app.models.mods import mods_to_int
|
from app.models.mods import mods_to_int
|
||||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatistics
|
||||||
from app.models.signalr import serialize_to_list
|
|
||||||
from app.models.spectator_hub import (
|
from app.models.spectator_hub import (
|
||||||
APIUser,
|
APIUser,
|
||||||
FrameDataBundle,
|
FrameDataBundle,
|
||||||
@@ -69,8 +68,8 @@ def save_replay(
|
|||||||
md5: str,
|
md5: str,
|
||||||
username: str,
|
username: str,
|
||||||
score: Score,
|
score: Score,
|
||||||
statistics: ScoreStatisticsInt,
|
statistics: ScoreStatistics,
|
||||||
maximum_statistics: ScoreStatisticsInt,
|
maximum_statistics: ScoreStatistics,
|
||||||
frames: list[LegacyReplayFrame],
|
frames: list[LegacyReplayFrame],
|
||||||
) -> None:
|
) -> None:
|
||||||
data = bytearray()
|
data = bytearray()
|
||||||
@@ -107,8 +106,8 @@ def save_replay(
|
|||||||
last_time = 0
|
last_time = 0
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
frame_strs.append(
|
frame_strs.append(
|
||||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
f"{frame.time - last_time}|{frame.mouse_x or 0.0}"
|
||||||
f"|{frame.y or 0.0}|{frame.button_state}"
|
f"|{frame.mouse_y or 0.0}|{frame.button_state}"
|
||||||
)
|
)
|
||||||
last_time = frame.time
|
last_time = frame.time
|
||||||
frame_strs.append("-12345|0|0|0")
|
frame_strs.append("-12345|0|0|0")
|
||||||
@@ -165,9 +164,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
|
|
||||||
async def on_client_connect(self, client: Client) -> None:
|
async def on_client_connect(self, client: Client) -> None:
|
||||||
tasks = [
|
tasks = [
|
||||||
self.call_noblock(
|
self.call_noblock(client, "UserBeganPlaying", user_id, store.state)
|
||||||
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
|
||||||
)
|
|
||||||
for user_id, store in self.state.items()
|
for user_id, store in self.state.items()
|
||||||
if store.state is not None
|
if store.state is not None
|
||||||
]
|
]
|
||||||
@@ -214,7 +211,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
self.group_id(user_id),
|
self.group_id(user_id),
|
||||||
"UserBeganPlaying",
|
"UserBeganPlaying",
|
||||||
user_id,
|
user_id,
|
||||||
serialize_to_list(state),
|
state,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||||
@@ -222,7 +219,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
state = self.get_or_create_state(client)
|
state = self.get_or_create_state(client)
|
||||||
if not state.score:
|
if not state.score:
|
||||||
return
|
return
|
||||||
state.score.score_info.acc = frame_data.header.acc
|
state.score.score_info.accuracy = frame_data.header.accuracy
|
||||||
state.score.score_info.combo = frame_data.header.combo
|
state.score.score_info.combo = frame_data.header.combo
|
||||||
state.score.score_info.max_combo = frame_data.header.max_combo
|
state.score.score_info.max_combo = frame_data.header.max_combo
|
||||||
state.score.score_info.statistics = frame_data.header.statistics
|
state.score.score_info.statistics = frame_data.header.statistics
|
||||||
@@ -233,7 +230,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
self.group_id(user_id),
|
self.group_id(user_id),
|
||||||
"UserSentFrames",
|
"UserSentFrames",
|
||||||
user_id,
|
user_id,
|
||||||
frame_data.model_dump(),
|
frame_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||||
@@ -316,7 +313,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
self.group_id(user_id),
|
self.group_id(user_id),
|
||||||
"UserFinishedPlaying",
|
"UserFinishedPlaying",
|
||||||
user_id,
|
user_id,
|
||||||
serialize_to_list(state) if state else None,
|
state,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||||
@@ -327,7 +324,7 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
client,
|
client,
|
||||||
"UserBeganPlaying",
|
"UserBeganPlaying",
|
||||||
target_id,
|
target_id,
|
||||||
serialize_to_list(target_store.state),
|
target_store.state,
|
||||||
)
|
)
|
||||||
store = self.get_or_create_state(client)
|
store = self.get_or_create_state(client)
|
||||||
store.watched_user.add(target_id)
|
store.watched_user.add(target_id)
|
||||||
|
|||||||
@@ -1,16 +1,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
import datetime
|
||||||
|
from enum import Enum, IntEnum
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
from types import NoneType, UnionType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Protocol as TypingProtocol,
|
Protocol as TypingProtocol,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.models.signalr import serialize_msgpack
|
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
||||||
|
from app.utils import camel_to_snake, snake_to_camel
|
||||||
|
|
||||||
import msgpack_lazer_api as m
|
import msgpack_lazer_api as m
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
SEP = b"\x1e"
|
SEP = b"\x1e"
|
||||||
|
|
||||||
@@ -75,8 +83,61 @@ class Protocol(TypingProtocol):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def encode(packet: Packet) -> bytes: ...
|
def encode(packet: Packet) -> bytes: ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_object(cls, v: Any, typ: type) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
class MsgpackProtocol:
|
class MsgpackProtocol:
|
||||||
|
@classmethod
|
||||||
|
def serialize_msgpack(cls, v: Any) -> Any:
|
||||||
|
typ = v.__class__
|
||||||
|
if issubclass(typ, BaseModel):
|
||||||
|
return cls.serialize_to_list(v)
|
||||||
|
elif issubclass(typ, list):
|
||||||
|
return [cls.serialize_msgpack(item) for item in v]
|
||||||
|
elif issubclass(typ, datetime.datetime):
|
||||||
|
return [v, 0]
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
return {
|
||||||
|
cls.serialize_msgpack(k): cls.serialize_msgpack(value)
|
||||||
|
for k, value in v.items()
|
||||||
|
}
|
||||||
|
elif issubclass(typ, Enum):
|
||||||
|
list_ = list(typ)
|
||||||
|
return list_.index(v) if v in list_ else v.value
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
|
||||||
|
values = []
|
||||||
|
for field, info in value.__class__.model_fields.items():
|
||||||
|
metadata = next(
|
||||||
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||||
|
)
|
||||||
|
if metadata and metadata.member_ignore:
|
||||||
|
continue
|
||||||
|
values.append(cls.serialize_msgpack(v=getattr(value, field)))
|
||||||
|
if issubclass(value.__class__, SignalRUnionMessage):
|
||||||
|
return [value.__class__.union_type, values]
|
||||||
|
else:
|
||||||
|
return values
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
||||||
|
if isinstance(v, list):
|
||||||
|
d = {}
|
||||||
|
for i, f in enumerate(typ.model_fields.items()):
|
||||||
|
field, info = f
|
||||||
|
if info.exclude:
|
||||||
|
continue
|
||||||
|
anno = info.annotation
|
||||||
|
if anno is None:
|
||||||
|
d[camel_to_snake(field)] = v[i]
|
||||||
|
continue
|
||||||
|
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||||
|
return d
|
||||||
|
return v
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _encode_varint(value: int) -> bytes:
|
def _encode_varint(value: int) -> bytes:
|
||||||
result = []
|
result = []
|
||||||
@@ -142,6 +203,49 @@ class MsgpackProtocol:
|
|||||||
]
|
]
|
||||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_object(cls, v: Any, typ: type) -> Any:
|
||||||
|
if issubclass(typ, BaseModel):
|
||||||
|
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||||
|
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||||
|
return v[0]
|
||||||
|
elif isinstance(v, list):
|
||||||
|
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||||
|
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||||
|
list_ = list(typ)
|
||||||
|
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||||
|
elif get_origin(typ) is dict:
|
||||||
|
return {
|
||||||
|
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
||||||
|
v, get_args(typ)[1]
|
||||||
|
)
|
||||||
|
for k, v in v.items()
|
||||||
|
}
|
||||||
|
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||||
|
args = get_args(typ)
|
||||||
|
if len(args) == 2 and NoneType in args:
|
||||||
|
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||||
|
if len(non_none_args) == 1:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
return cls.validate_object(v, non_none_args[0])
|
||||||
|
|
||||||
|
# suppose use `MessagePack-CSharp Union | None`
|
||||||
|
# except `X (Other Type) | None`
|
||||||
|
if NoneType in args and v is None:
|
||||||
|
return None
|
||||||
|
if not all(issubclass(arg, SignalRUnionMessage) for arg in args):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot validate {v} to {typ}, "
|
||||||
|
"only SignalRUnionMessage subclasses are supported"
|
||||||
|
)
|
||||||
|
union_type = v[0]
|
||||||
|
for arg in args:
|
||||||
|
assert issubclass(arg, SignalRUnionMessage)
|
||||||
|
if arg.union_type == union_type:
|
||||||
|
return cls.validate_object(v[1], arg)
|
||||||
|
return v
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode(packet: Packet) -> bytes:
|
def encode(packet: Packet) -> bytes:
|
||||||
payload = [packet.type.value, packet.header or {}]
|
payload = [packet.type.value, packet.header or {}]
|
||||||
@@ -153,7 +257,9 @@ class MsgpackProtocol:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
if packet.arguments is not None:
|
if packet.arguments is not None:
|
||||||
payload.append([serialize_msgpack(arg) for arg in packet.arguments])
|
payload.append(
|
||||||
|
[MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]
|
||||||
|
)
|
||||||
if packet.stream_ids is not None:
|
if packet.stream_ids is not None:
|
||||||
payload.append(packet.stream_ids)
|
payload.append(packet.stream_ids)
|
||||||
elif isinstance(packet, CompletionPacket):
|
elif isinstance(packet, CompletionPacket):
|
||||||
@@ -166,7 +272,9 @@ class MsgpackProtocol:
|
|||||||
[
|
[
|
||||||
packet.invocation_id,
|
packet.invocation_id,
|
||||||
result_kind,
|
result_kind,
|
||||||
packet.error or packet.result or None,
|
packet.error
|
||||||
|
or MsgpackProtocol.serialize_msgpack(packet.result)
|
||||||
|
or None,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
elif isinstance(packet, ClosePacket):
|
elif isinstance(packet, ClosePacket):
|
||||||
@@ -183,6 +291,62 @@ class MsgpackProtocol:
|
|||||||
|
|
||||||
|
|
||||||
class JSONProtocol:
|
class JSONProtocol:
|
||||||
|
@classmethod
|
||||||
|
def serialize_to_json(cls, v: Any):
|
||||||
|
typ = v.__class__
|
||||||
|
if issubclass(typ, BaseModel):
|
||||||
|
return cls.serialize_model(v)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
return {
|
||||||
|
cls.serialize_to_json(k): cls.serialize_to_json(value)
|
||||||
|
for k, value in v.items()
|
||||||
|
}
|
||||||
|
elif isinstance(v, list):
|
||||||
|
return [cls.serialize_to_json(item) for item in v]
|
||||||
|
elif isinstance(v, datetime.datetime):
|
||||||
|
return v.isoformat()
|
||||||
|
elif isinstance(v, Enum):
|
||||||
|
return v.value
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serialize_model(cls, v: BaseModel) -> dict[str, Any]:
|
||||||
|
d = {}
|
||||||
|
for field, info in v.__class__.model_fields.items():
|
||||||
|
metadata = next(
|
||||||
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||||
|
)
|
||||||
|
if metadata and metadata.json_ignore:
|
||||||
|
continue
|
||||||
|
d[snake_to_camel(field, metadata.use_upper_case if metadata else False)] = (
|
||||||
|
cls.serialize_to_json(getattr(v, field))
|
||||||
|
)
|
||||||
|
if issubclass(v.__class__, SignalRUnionMessage):
|
||||||
|
return {
|
||||||
|
"$dtype": v.__class__.__name__,
|
||||||
|
"$value": d,
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_object(
|
||||||
|
v: Any, typ: type[BaseModel], from_union: bool = False
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
d = {}
|
||||||
|
for field, info in typ.model_fields.items():
|
||||||
|
metadata = next(
|
||||||
|
(m for m in info.metadata if isinstance(m, SignalRMeta)), None
|
||||||
|
)
|
||||||
|
if metadata and metadata.json_ignore:
|
||||||
|
continue
|
||||||
|
value = v.get(snake_to_camel(field, not from_union))
|
||||||
|
anno = typ.model_fields[field].annotation
|
||||||
|
if anno is None:
|
||||||
|
d[field] = value
|
||||||
|
continue
|
||||||
|
d[field] = JSONProtocol.validate_object(value, anno)
|
||||||
|
return d
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def decode(input: bytes) -> list[Packet]:
|
def decode(input: bytes) -> list[Packet]:
|
||||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||||
@@ -227,6 +391,52 @@ class JSONProtocol:
|
|||||||
]
|
]
|
||||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
|
||||||
|
if issubclass(typ, BaseModel):
|
||||||
|
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||||
|
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||||
|
return datetime.datetime.fromisoformat(v)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||||
|
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||||
|
list_ = list(typ)
|
||||||
|
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||||
|
elif get_origin(typ) is dict:
|
||||||
|
return {
|
||||||
|
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(
|
||||||
|
v, get_args(typ)[1]
|
||||||
|
)
|
||||||
|
for k, v in v.items()
|
||||||
|
}
|
||||||
|
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||||
|
args = get_args(typ)
|
||||||
|
if len(args) == 2 and NoneType in args:
|
||||||
|
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||||
|
if len(non_none_args) == 1:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
return cls.validate_object(v, non_none_args[0])
|
||||||
|
|
||||||
|
# suppose use `MessagePack-CSharp Union | None`
|
||||||
|
# except `X (Other Type) | None`
|
||||||
|
if NoneType in args and v is None:
|
||||||
|
return None
|
||||||
|
if not all(
|
||||||
|
issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot validate {v} to {typ}, "
|
||||||
|
"only SignalRUnionMessage subclasses are supported"
|
||||||
|
)
|
||||||
|
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
|
||||||
|
union_type = v["$dtype"]
|
||||||
|
for arg in args:
|
||||||
|
assert issubclass(arg, SignalRUnionMessage)
|
||||||
|
if arg.__name__ == union_type:
|
||||||
|
return cls.validate_object(v["$value"], arg, True)
|
||||||
|
return v
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def encode(packet: Packet) -> bytes:
|
def encode(packet: Packet) -> bytes:
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
@@ -243,7 +453,9 @@ class JSONProtocol:
|
|||||||
if packet.invocation_id is not None:
|
if packet.invocation_id is not None:
|
||||||
payload["invocationId"] = packet.invocation_id
|
payload["invocationId"] = packet.invocation_id
|
||||||
if packet.arguments is not None:
|
if packet.arguments is not None:
|
||||||
payload["arguments"] = packet.arguments
|
payload["arguments"] = [
|
||||||
|
JSONProtocol.serialize_to_json(arg) for arg in packet.arguments
|
||||||
|
]
|
||||||
if packet.stream_ids is not None:
|
if packet.stream_ids is not None:
|
||||||
payload["streamIds"] = packet.stream_ids
|
payload["streamIds"] = packet.stream_ids
|
||||||
elif isinstance(packet, CompletionPacket):
|
elif isinstance(packet, CompletionPacket):
|
||||||
@@ -255,7 +467,7 @@ class JSONProtocol:
|
|||||||
if packet.error is not None:
|
if packet.error is not None:
|
||||||
payload["error"] = packet.error
|
payload["error"] = packet.error
|
||||||
if packet.result is not None:
|
if packet.result is not None:
|
||||||
payload["result"] = packet.result
|
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
|
||||||
elif isinstance(packet, PingPacket):
|
elif isinstance(packet, PingPacket):
|
||||||
pass
|
pass
|
||||||
elif isinstance(packet, ClosePacket):
|
elif isinstance(packet, ClosePacket):
|
||||||
|
|||||||
52
app/utils.py
52
app/utils.py
@@ -4,3 +4,55 @@ from __future__ import annotations
|
|||||||
def unix_timestamp_to_windows(timestamp: int) -> int:
|
def unix_timestamp_to_windows(timestamp: int) -> int:
|
||||||
"""Convert a Unix timestamp to a Windows timestamp."""
|
"""Convert a Unix timestamp to a Windows timestamp."""
|
||||||
return (timestamp + 62135596800) * 10_000_000
|
return (timestamp + 62135596800) * 10_000_000
|
||||||
|
|
||||||
|
|
||||||
|
def camel_to_snake(name: str) -> str:
|
||||||
|
"""Convert a camelCase string to snake_case."""
|
||||||
|
result = []
|
||||||
|
last_chr = ""
|
||||||
|
for char in name:
|
||||||
|
if char.isupper():
|
||||||
|
if not last_chr.isupper() and result:
|
||||||
|
result.append("_")
|
||||||
|
result.append(char.lower())
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
last_chr = char
|
||||||
|
return "".join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def snake_to_camel(name: str, lower_case: bool = True) -> str:
|
||||||
|
"""Convert a snake_case string to camelCase."""
|
||||||
|
if not name:
|
||||||
|
return name
|
||||||
|
|
||||||
|
parts = name.split("_")
|
||||||
|
if not parts:
|
||||||
|
return name
|
||||||
|
|
||||||
|
# 常见缩写词列表
|
||||||
|
abbreviations = {
|
||||||
|
"id",
|
||||||
|
"url",
|
||||||
|
"api",
|
||||||
|
"http",
|
||||||
|
"https",
|
||||||
|
"xml",
|
||||||
|
"json",
|
||||||
|
"css",
|
||||||
|
"html",
|
||||||
|
"sql",
|
||||||
|
"db",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for part in parts:
|
||||||
|
if part.lower() in abbreviations:
|
||||||
|
result.append(part.upper())
|
||||||
|
else:
|
||||||
|
if result or not lower_case:
|
||||||
|
result.append(part.capitalize())
|
||||||
|
else:
|
||||||
|
result.append(part.lower())
|
||||||
|
|
||||||
|
return "".join(result)
|
||||||
|
|||||||
@@ -1,11 +1,4 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
class APIMod:
|
|
||||||
def __init__(self, acronym: str, settings: dict[str, Any]) -> None: ...
|
|
||||||
@property
|
|
||||||
def acronym(self) -> str: ...
|
|
||||||
@property
|
|
||||||
def settings(self) -> dict[str, Any]: ...
|
|
||||||
|
|
||||||
def encode(obj: Any) -> bytes: ...
|
def encode(obj: Any) -> bytes: ...
|
||||||
def decode(data: bytes) -> Any: ...
|
def decode(data: bytes) -> Any: ...
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
use crate::APIMod;
|
|
||||||
use chrono::{TimeZone, Utc};
|
use chrono::{TimeZone, Utc};
|
||||||
use pyo3::types::PyDict;
|
use pyo3::types::PyDict;
|
||||||
use pyo3::{prelude::*, IntoPyObjectExt};
|
use pyo3::{prelude::*, IntoPyObjectExt};
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
|
||||||
pub fn read_object(
|
pub fn read_object(
|
||||||
@@ -206,13 +204,12 @@ fn read_array(
|
|||||||
let obj1 = read_object(py, cursor, false)?;
|
let obj1 = read_object(py, cursor, false)?;
|
||||||
if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) {
|
if obj1.extract::<String>(py).map_or(false, |k| k.len() == 2) {
|
||||||
let obj2 = read_object(py, cursor, true)?;
|
let obj2 = read_object(py, cursor, true)?;
|
||||||
return Ok(APIMod {
|
|
||||||
acronym: obj1.extract::<String>(py)?,
|
let api_mod_dict = PyDict::new(py);
|
||||||
settings: obj2.extract::<HashMap<String, PyObject>>(py)?,
|
api_mod_dict.set_item("acronym", obj1)?;
|
||||||
}
|
api_mod_dict.set_item("settings", obj2)?;
|
||||||
.into_pyobject(py)?
|
|
||||||
.into_any()
|
return Ok(api_mod_dict.into_pyobject(py)?.into_any().unbind());
|
||||||
.unbind());
|
|
||||||
} else {
|
} else {
|
||||||
items.push(obj1);
|
items.push(obj1);
|
||||||
i += 1;
|
i += 1;
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use crate::APIMod;
|
use chrono::{DateTime, Utc};
|
||||||
use chrono::{DateTime, Utc};
|
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyResult, PyStringMethods};
|
||||||
use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods, PyStringMethods};
|
|
||||||
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
|
use pyo3::types::{PyBool, PyBytes, PyDateTime, PyDict, PyFloat, PyInt, PyList, PyNone, PyString};
|
||||||
use pyo3::{Bound, PyAny, PyRef, Python};
|
use pyo3::{Bound, PyAny};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) {
|
fn write_list(buf: &mut Vec<u8>, obj: &Bound<'_, PyList>) {
|
||||||
@@ -61,19 +60,42 @@ fn write_hashmap(buf: &mut Vec<u8>, obj: &Bound<'_, PyDict>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_nil(buf: &mut Vec<u8>){
|
fn write_nil(buf: &mut Vec<u8>) {
|
||||||
rmp::encode::write_nil(buf).unwrap();
|
rmp::encode::write_nil(buf).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
fn is_api_mod(dict: &Bound<'_, PyDict>) -> bool {
|
||||||
fn write_api_mod(buf: &mut Vec<u8>, api_mod: PyRef<APIMod>) {
|
if let Ok(Some(acronym)) = dict.get_item("acronym") {
|
||||||
rmp::encode::write_array_len(buf, 2).unwrap();
|
if let Ok(acronym_str) = acronym.extract::<String>() {
|
||||||
rmp::encode::write_str(buf, &api_mod.acronym).unwrap();
|
return acronym_str.len() == 2;
|
||||||
rmp::encode::write_array_len(buf, api_mod.settings.len() as u32).unwrap();
|
}
|
||||||
for (k, v) in api_mod.settings.iter() {
|
|
||||||
rmp::encode::write_str(buf, k).unwrap();
|
|
||||||
Python::with_gil(|py| write_object(buf, &v.bind(py)));
|
|
||||||
}
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/ppy/osu/blob/3dced3/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
||||||
|
fn write_api_mod(buf: &mut Vec<u8>, api_mod: &Bound<'_, PyDict>) -> PyResult<()> {
|
||||||
|
let acronym = api_mod
|
||||||
|
.get_item("acronym")?
|
||||||
|
.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("APIMod missing 'acronym' field"))?;
|
||||||
|
let acronym_str = acronym.extract::<String>()?;
|
||||||
|
|
||||||
|
let settings = api_mod
|
||||||
|
.get_item("settings")?
|
||||||
|
.unwrap_or_else(|| PyDict::new(acronym.py()).into_any());
|
||||||
|
let settings_dict = settings.downcast::<PyDict>()?;
|
||||||
|
|
||||||
|
rmp::encode::write_array_len(buf, 2).unwrap();
|
||||||
|
rmp::encode::write_str(buf, &acronym_str).unwrap();
|
||||||
|
rmp::encode::write_array_len(buf, settings_dict.len() as u32).unwrap();
|
||||||
|
|
||||||
|
for (k, v) in settings_dict.iter() {
|
||||||
|
let key_str = k.extract::<String>()?;
|
||||||
|
rmp::encode::write_str(buf, &key_str).unwrap();
|
||||||
|
write_object(buf, &v);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) {
|
fn write_datetime(buf: &mut Vec<u8>, obj: &Bound<'_, PyDateTime>) {
|
||||||
@@ -111,21 +133,23 @@ pub fn write_object(buf: &mut Vec<u8>, obj: &Bound<'_, PyAny>) {
|
|||||||
} else if let Ok(string) = obj.downcast::<PyString>() {
|
} else if let Ok(string) = obj.downcast::<PyString>() {
|
||||||
write_string(buf, string);
|
write_string(buf, string);
|
||||||
} else if let Ok(boolean) = obj.downcast::<PyBool>() {
|
} else if let Ok(boolean) = obj.downcast::<PyBool>() {
|
||||||
write_bool(buf, boolean);
|
write_bool(buf, boolean);
|
||||||
} else if let Ok(float) = obj.downcast::<PyFloat>() {
|
} else if let Ok(float) = obj.downcast::<PyFloat>() {
|
||||||
write_float(buf, float);
|
write_float(buf, float);
|
||||||
} else if let Ok(integer) = obj.downcast::<PyInt>() {
|
} else if let Ok(integer) = obj.downcast::<PyInt>() {
|
||||||
write_integer(buf, integer);
|
write_integer(buf, integer);
|
||||||
} else if let Ok(bytes) = obj.downcast::<PyBytes>() {
|
} else if let Ok(bytes) = obj.downcast::<PyBytes>() {
|
||||||
write_bin(buf, bytes);
|
write_bin(buf, bytes);
|
||||||
} else if let Ok(dict) = obj.downcast::<PyDict>() {
|
} else if let Ok(dict) = obj.downcast::<PyDict>() {
|
||||||
write_hashmap(buf, dict);
|
if is_api_mod(dict) {
|
||||||
|
write_api_mod(buf, dict).unwrap_or_else(|_| write_hashmap(buf, dict));
|
||||||
|
} else {
|
||||||
|
write_hashmap(buf, dict);
|
||||||
|
}
|
||||||
} else if let Ok(_none) = obj.downcast::<PyNone>() {
|
} else if let Ok(_none) = obj.downcast::<PyNone>() {
|
||||||
write_nil(buf);
|
write_nil(buf);
|
||||||
} else if let Ok(datetime) = obj.downcast::<PyDateTime>() {
|
} else if let Ok(datetime) = obj.downcast::<PyDateTime>() {
|
||||||
write_datetime(buf, datetime);
|
write_datetime(buf, datetime);
|
||||||
} else if let Ok(api_mod) = obj.extract::<PyRef<APIMod>>() {
|
|
||||||
write_api_mod(buf, api_mod);
|
|
||||||
} else {
|
} else {
|
||||||
panic!("Unsupported type");
|
panic!("Unsupported type");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,30 +2,6 @@ mod decode;
|
|||||||
mod encode;
|
mod encode;
|
||||||
|
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[pyclass]
|
|
||||||
struct APIMod {
|
|
||||||
#[pyo3(get, set)]
|
|
||||||
acronym: String,
|
|
||||||
#[pyo3(get, set)]
|
|
||||||
settings: HashMap<String, PyObject>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymethods]
|
|
||||||
impl APIMod {
|
|
||||||
#[new]
|
|
||||||
fn new(acronym: String, settings: HashMap<String, PyObject>) -> Self {
|
|
||||||
APIMod { acronym, settings }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __repr__(&self) -> String {
|
|
||||||
format!(
|
|
||||||
"APIMod(acronym='{}', settings={:?})",
|
|
||||||
self.acronym, self.settings
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(name = "encode")]
|
#[pyo3(name = "encode")]
|
||||||
@@ -46,6 +22,5 @@ fn decode_py(py: Python, data: &[u8]) -> PyResult<PyObject> {
|
|||||||
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
fn msgpack_lazer_api(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(encode_py, m)?)?;
|
m.add_function(wrap_pyfunction!(encode_py, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(decode_py, m)?)?;
|
m.add_function(wrap_pyfunction!(decode_py, m)?)?;
|
||||||
m.add_class::<APIMod>()?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user