refactor(signalr): remove SignalR server & msgpack_lazer_api
Maybe we can make `msgpack_lazer_api` independent?
This commit is contained in:
@@ -266,18 +266,6 @@ STORAGE_SETTINGS='{
|
||||
else:
|
||||
return "/"
|
||||
|
||||
# SignalR 设置
|
||||
signalr_negotiate_timeout: Annotated[
|
||||
int,
|
||||
Field(default=30, description="SignalR 协商超时时间(秒)"),
|
||||
"SignalR 服务器设置",
|
||||
]
|
||||
signalr_ping_interval: Annotated[
|
||||
int,
|
||||
Field(default=15, description="SignalR ping 间隔(秒)"),
|
||||
"SignalR 服务器设置",
|
||||
]
|
||||
|
||||
# Fetcher 设置
|
||||
fetcher_client_id: Annotated[
|
||||
str,
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.playlist import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
|
||||
@@ -21,8 +22,6 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.multiplayer_hub import PlaylistItem
|
||||
|
||||
from .room import Room
|
||||
|
||||
|
||||
@@ -73,7 +72,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession) -> "Playlist":
|
||||
async def from_model(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
@@ -90,7 +89,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
@@ -107,8 +106,8 @@ class Playlist(PlaylistBase, table=True):
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_model(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
await session.refresh(db_playlist)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.database.item_attempts_count import PlaylistAggregateScore
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
@@ -32,9 +31,6 @@ from sqlmodel import (
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.multiplayer_hub import ServerMultiplayerRoom
|
||||
|
||||
|
||||
class RoomBase(SQLModel, UTCBaseModel):
|
||||
name: str = Field(index=True)
|
||||
@@ -163,25 +159,6 @@ class RoomResp(RoomBase):
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
async def from_hub(cls, server_room: "ServerMultiplayerRoom") -> "RoomResp":
|
||||
room = server_room.room
|
||||
resp = cls(
|
||||
id=room.room_id,
|
||||
name=room.settings.name,
|
||||
type=room.settings.match_type,
|
||||
queue_mode=room.settings.queue_mode,
|
||||
auto_skip=room.settings.auto_skip,
|
||||
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
|
||||
status=server_room.status,
|
||||
category=server_room.category,
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
channel_id=server_room.room.channel_id or 0,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class APIUploadedRoom(RoomBase):
|
||||
def to_room(self) -> Room:
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class SignalRException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvokeException(SignalRException):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
会话验证接口
|
||||
|
||||
基于osu-web的SessionVerificationInterface实现
|
||||
用于标准化会话验证行为
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SessionVerificationInterface(ABC):
|
||||
"""会话验证接口
|
||||
|
||||
定义了会话验证所需的基本操作,参考osu-web的实现
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def find_for_verification(cls, session_id: str) -> SessionVerificationInterface | None:
|
||||
"""根据会话ID查找会话用于验证
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
会话实例或None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥/ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_verification_method(self) -> str | None:
|
||||
"""获取当前验证方法
|
||||
|
||||
Returns:
|
||||
验证方法 ('totp', 'mail') 或 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_verification_method(self, method: str) -> None:
|
||||
"""设置验证方法
|
||||
|
||||
Args:
|
||||
method: 验证方法 ('totp', 'mail')
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def user_id(self) -> int | None:
|
||||
"""获取关联的用户ID"""
|
||||
pass
|
||||
@@ -1,157 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from app.models.signalr import SignalRUnionMessage, UserState
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS = 13
|
||||
|
||||
|
||||
class _UserActivity(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChoosingBeatmap(_UserActivity):
|
||||
union_type: ClassVar[Literal[11]] = 11
|
||||
|
||||
|
||||
class _InGame(_UserActivity):
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
ruleset_id: int
|
||||
ruleset_playing_verb: str
|
||||
|
||||
|
||||
class InSoloGame(_InGame):
|
||||
union_type: ClassVar[Literal[12]] = 12
|
||||
|
||||
|
||||
class InMultiplayerGame(_InGame):
|
||||
union_type: ClassVar[Literal[23]] = 23
|
||||
|
||||
|
||||
class SpectatingMultiplayerGame(_InGame):
|
||||
union_type: ClassVar[Literal[24]] = 24
|
||||
|
||||
|
||||
class InPlaylistGame(_InGame):
|
||||
union_type: ClassVar[Literal[31]] = 31
|
||||
|
||||
|
||||
class PlayingDailyChallenge(_InGame):
|
||||
union_type: ClassVar[Literal[52]] = 52
|
||||
|
||||
|
||||
class EditingBeatmap(_UserActivity):
|
||||
union_type: ClassVar[Literal[41]] = 41
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
|
||||
|
||||
class TestingBeatmap(EditingBeatmap):
|
||||
union_type: ClassVar[Literal[43]] = 43
|
||||
|
||||
|
||||
class ModdingBeatmap(EditingBeatmap):
|
||||
union_type: ClassVar[Literal[42]] = 42
|
||||
|
||||
|
||||
class WatchingReplay(_UserActivity):
|
||||
union_type: ClassVar[Literal[13]] = 13
|
||||
score_id: int
|
||||
player_name: str
|
||||
beatmap_id: int
|
||||
beatmap_display_title: str
|
||||
|
||||
|
||||
class SpectatingUser(WatchingReplay):
|
||||
union_type: ClassVar[Literal[14]] = 14
|
||||
|
||||
|
||||
class SearchingForLobby(_UserActivity):
|
||||
union_type: ClassVar[Literal[21]] = 21
|
||||
|
||||
|
||||
class InLobby(_UserActivity):
|
||||
union_type: ClassVar[Literal[22]] = 22
|
||||
room_id: int
|
||||
room_name: str
|
||||
|
||||
|
||||
class InDailyChallengeLobby(_UserActivity):
|
||||
union_type: ClassVar[Literal[51]] = 51
|
||||
|
||||
|
||||
UserActivity = (
|
||||
ChoosingBeatmap
|
||||
| InSoloGame
|
||||
| WatchingReplay
|
||||
| SpectatingUser
|
||||
| SearchingForLobby
|
||||
| InLobby
|
||||
| InMultiplayerGame
|
||||
| SpectatingMultiplayerGame
|
||||
| InPlaylistGame
|
||||
| EditingBeatmap
|
||||
| ModdingBeatmap
|
||||
| TestingBeatmap
|
||||
| InDailyChallengeLobby
|
||||
| PlayingDailyChallenge
|
||||
)
|
||||
|
||||
|
||||
class UserPresence(BaseModel):
|
||||
activity: UserActivity | None = None
|
||||
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
@property
|
||||
def pushable(self) -> bool:
|
||||
return self.status is not None and self.status != OnlineStatus.OFFLINE
|
||||
|
||||
@property
|
||||
def for_push(self) -> "UserPresence | None":
|
||||
return UserPresence(
|
||||
activity=self.activity,
|
||||
status=self.status,
|
||||
)
|
||||
|
||||
|
||||
class MetadataClientState(UserPresence, UserState): ...
|
||||
|
||||
|
||||
class OnlineStatus(IntEnum):
|
||||
OFFLINE = 0 # 隐身
|
||||
DO_NOT_DISTURB = 1
|
||||
ONLINE = 2
|
||||
|
||||
|
||||
class DailyChallengeInfo(BaseModel):
|
||||
room_id: int
|
||||
|
||||
|
||||
class MultiplayerPlaylistItemStats(BaseModel):
|
||||
playlist_item_id: int = 0
|
||||
total_score_distribution: list[int] = Field(
|
||||
default_factory=list,
|
||||
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
)
|
||||
cumulative_score: int = 0
|
||||
last_processed_score_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomStats(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class MultiplayerRoomScoreSetEvent(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_id: int
|
||||
score_id: int
|
||||
user_id: int
|
||||
total_score: int
|
||||
new_rank: int | None = None
|
||||
|
||||
@@ -1,840 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import IntEnum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
TypedDict,
|
||||
cast,
|
||||
override,
|
||||
)
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.dependencies.database import with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
from app.utils import utcnow
|
||||
|
||||
from .mods import API_MODS, APIMod
|
||||
from .room import (
|
||||
DownloadState,
|
||||
MatchType,
|
||||
MultiplayerRoomState,
|
||||
MultiplayerUserState,
|
||||
QueueMode,
|
||||
RoomCategory,
|
||||
RoomStatus,
|
||||
)
|
||||
from .signalr import (
|
||||
SignalRMeta,
|
||||
SignalRUnionMessage,
|
||||
UserState,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import col
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.room import Room
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
PER_USER_LIMIT = 3
|
||||
|
||||
|
||||
class MultiplayerClientState(UserState):
|
||||
room_id: int = 0
|
||||
|
||||
|
||||
class MultiplayerRoomSettings(BaseModel):
|
||||
name: str = "Unnamed Room"
|
||||
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
password: str = ""
|
||||
match_type: MatchType = MatchType.HEAD_TO_HEAD
|
||||
queue_mode: QueueMode = QueueMode.HOST_ONLY
|
||||
auto_start_duration: timedelta = timedelta(seconds=0)
|
||||
auto_skip: bool = False
|
||||
|
||||
@property
|
||||
def auto_start_enabled(self) -> bool:
|
||||
return self.auto_start_duration != timedelta(seconds=0)
|
||||
|
||||
|
||||
class BeatmapAvailability(BaseModel):
|
||||
state: DownloadState = DownloadState.UNKNOWN
|
||||
download_progress: float | None = None
|
||||
|
||||
|
||||
class _MatchUserState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class TeamVersusUserState(_MatchUserState):
|
||||
team_id: int
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchUserState = TeamVersusUserState
|
||||
|
||||
|
||||
class _MatchRoomState(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class MultiplayerTeam(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class TeamVersusRoomState(_MatchRoomState):
|
||||
teams: list[MultiplayerTeam] = Field(
|
||||
default_factory=lambda: [
|
||||
MultiplayerTeam(id=0, name="Team Red"),
|
||||
MultiplayerTeam(id=1, name="Team Blue"),
|
||||
]
|
||||
)
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
MatchRoomState = TeamVersusRoomState
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str
|
||||
ruleset_id: int
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool
|
||||
playlist_order: int
|
||||
played_at: datetime | None = None
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
|
||||
if mod1_settings:
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
for req_mod in self.required_mods:
|
||||
req_acronym = req_mod["acronym"]
|
||||
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
|
||||
if req_settings:
|
||||
incompatible = set(req_settings.get("IncompatibleMods", []))
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
|
||||
# Validate required mods
|
||||
for mod in self.required_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
|
||||
|
||||
# Validate allowed mods
|
||||
for mod in self.allowed_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
|
||||
|
||||
# Check internal compatibility of required mods
|
||||
self._check_mod_compatibility(self.required_mods, ruleset_key)
|
||||
|
||||
# Check compatibility between required and allowed mods
|
||||
self._check_required_allowed_compatibility(ruleset_key)
|
||||
|
||||
def validate_user_mods(
|
||||
self,
|
||||
user: "MultiplayerRoomUser",
|
||||
proposed_mods: list[APIMod],
|
||||
) -> tuple[bool, list[APIMod]]:
|
||||
"""
|
||||
Validates user mods against playlist item rules and returns valid mods.
|
||||
Returns (is_valid, valid_mods).
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
|
||||
valid_mods = []
|
||||
all_proposed_valid = True
|
||||
|
||||
# Check if mods are valid for the ruleset
|
||||
for mod in proposed_mods:
|
||||
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
valid_mods.append(mod)
|
||||
|
||||
# Check mod compatibility within user mods
|
||||
incompatible_mods = set()
|
||||
final_valid_mods = []
|
||||
for mod in valid_mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
|
||||
if setting_mods:
|
||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||
final_valid_mods.append(mod)
|
||||
|
||||
# If not freestyle, check against allowed mods
|
||||
if not self.freestyle:
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
if mod["acronym"] not in allowed_acronyms:
|
||||
all_proposed_valid = False
|
||||
else:
|
||||
filtered_valid_mods.append(mod)
|
||||
final_valid_mods = filtered_valid_mods
|
||||
|
||||
# Check compatibility with required mods
|
||||
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
|
||||
|
||||
# Check for incompatibility between required and user mods
|
||||
filtered_valid_mods = []
|
||||
for mod in final_valid_mods:
|
||||
mod_acronym = mod["acronym"]
|
||||
is_compatible = True
|
||||
|
||||
for other_acronym in all_mod_acronyms:
|
||||
if other_acronym == mod_acronym:
|
||||
continue
|
||||
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
|
||||
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
|
||||
is_compatible = False
|
||||
all_proposed_valid = False
|
||||
break
|
||||
|
||||
if is_compatible:
|
||||
filtered_valid_mods.append(mod)
|
||||
|
||||
return all_proposed_valid, filtered_valid_mods
|
||||
|
||||
def clone(self) -> "PlaylistItem":
|
||||
copy = self.model_copy()
|
||||
copy.required_mods = list(self.required_mods)
|
||||
copy.allowed_mods = list(self.allowed_mods)
|
||||
copy.expired = False
|
||||
copy.played_at = None
|
||||
return copy
|
||||
|
||||
|
||||
class _MultiplayerCountdown(SignalRUnionMessage):
|
||||
id: int = 0
|
||||
time_remaining: timedelta
|
||||
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
|
||||
|
||||
|
||||
class MatchStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class ForceGameplayStartCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
|
||||
|
||||
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
user_id: int
|
||||
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
match_state: MatchUserState | None = None
|
||||
ruleset_id: int | None = None # freestyle
|
||||
beatmap_id: int | None = None # freestyle
|
||||
|
||||
|
||||
class MultiplayerRoom(BaseModel):
|
||||
room_id: int
|
||||
state: MultiplayerRoomState
|
||||
settings: MultiplayerRoomSettings
|
||||
users: list[MultiplayerRoomUser] = Field(default_factory=list)
|
||||
host: MultiplayerRoomUser | None = None
|
||||
match_state: MatchRoomState | None = None
|
||||
playlist: list[PlaylistItem] = Field(default_factory=list)
|
||||
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room: "Room") -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
|
||||
# 用户列表
|
||||
users = [MultiplayerRoomUser(user_id=room.host_id)]
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if room.playlist:
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
id=item.id,
|
||||
owner_id=item.owner_id,
|
||||
beatmap_id=item.beatmap_id,
|
||||
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
|
||||
ruleset_id=item.ruleset_id,
|
||||
required_mods=item.required_mods,
|
||||
allowed_mods=item.allowed_mods,
|
||||
expired=item.expired,
|
||||
playlist_order=item.playlist_order,
|
||||
played_at=item.played_at,
|
||||
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
|
||||
freestyle=item.freestyle,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
room_id=room.id,
|
||||
state=getattr(room, "state", MultiplayerRoomState.OPEN),
|
||||
settings=MultiplayerRoomSettings(
|
||||
name=room.name,
|
||||
playlist_item_id=playlist[0].id if playlist else 0,
|
||||
password=getattr(room, "password", ""),
|
||||
match_type=room.type,
|
||||
queue_mode=room.queue_mode,
|
||||
auto_start_duration=timedelta(seconds=room.auto_start_duration),
|
||||
auto_skip=room.auto_skip,
|
||||
),
|
||||
users=users,
|
||||
host=host_user,
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=room.channel_id or 0,
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerQueue:
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.server_room = room
|
||||
self.current_index = 0
|
||||
|
||||
@property
|
||||
def hub(self) -> "MultiplayerHub":
|
||||
return self.server_room.hub
|
||||
|
||||
@property
|
||||
def upcoming_items(self):
|
||||
return sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda i: i.playlist_order,
|
||||
)
|
||||
|
||||
@property
|
||||
def room(self):
|
||||
return self.server_room.room
|
||||
|
||||
async def update_order(self):
|
||||
from app.database import Playlist
|
||||
|
||||
match self.room.settings.queue_mode:
|
||||
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
|
||||
ordered_active_items = []
|
||||
|
||||
is_first_set = True
|
||||
first_set_order_by_user_id = {}
|
||||
|
||||
active_items = [item for item in self.room.playlist if not item.expired]
|
||||
active_items.sort(key=lambda x: x.id)
|
||||
|
||||
user_item_groups = {}
|
||||
for item in active_items:
|
||||
if item.owner_id not in user_item_groups:
|
||||
user_item_groups[item.owner_id] = []
|
||||
user_item_groups[item.owner_id].append(item)
|
||||
|
||||
max_items = max((len(items) for items in user_item_groups.values()), default=0)
|
||||
|
||||
for i in range(max_items):
|
||||
current_set = []
|
||||
for user_id, items in user_item_groups.items():
|
||||
if i < len(items):
|
||||
current_set.append(items[i])
|
||||
|
||||
if is_first_set:
|
||||
current_set.sort(key=lambda item: (item.playlist_order, item.id))
|
||||
ordered_active_items.extend(current_set)
|
||||
first_set_order_by_user_id = {
|
||||
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
|
||||
}
|
||||
else:
|
||||
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
|
||||
ordered_active_items.extend(current_set)
|
||||
|
||||
is_first_set = False
|
||||
case _:
|
||||
ordered_active_items = sorted(
|
||||
(item for item in self.room.playlist if not item.expired),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
async with with_db() as session:
|
||||
for idx, item in enumerate(ordered_active_items):
|
||||
if item.playlist_order == idx:
|
||||
continue
|
||||
item.playlist_order = idx
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
|
||||
|
||||
async def update_current_item(self):
|
||||
upcoming_items = self.upcoming_items
|
||||
if upcoming_items:
|
||||
# 优先选择未过期的项目
|
||||
next_item = upcoming_items[0]
|
||||
else:
|
||||
# 如果所有项目都过期了,选择最近添加的项目(played_at 为 None 或最新的)
|
||||
# 优先选择 expired=False 的项目,然后是 played_at 最晚的
|
||||
next_item = max(
|
||||
self.room.playlist,
|
||||
key=lambda i: (not i.expired, i.played_at or datetime.min),
|
||||
)
|
||||
self.current_index = self.room.playlist.index(next_item)
|
||||
last_id = self.room.settings.playlist_item_id
|
||||
self.room.settings.playlist_item_id = next_item.id
|
||||
if last_id != next_item.id:
|
||||
await self.hub.setting_changed(self.server_room, True)
|
||||
|
||||
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
is_host = self.room.host and self.room.host.user_id == user.user_id
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
|
||||
raise InvokeException("You are not the host")
|
||||
|
||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
|
||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with with_db() as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = beatmap.difficulty_rating
|
||||
await Playlist.add_to_db(item, self.room.room_id, session)
|
||||
self.room.playlist.append(item)
|
||||
await self.hub.playlist_added(self.server_room, item)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
raise InvokeException("Freestyle items cannot have allowed mods")
|
||||
|
||||
async with with_db() as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||
if existing_item is None:
|
||||
raise InvokeException("Attempted to change an item that doesn't exist")
|
||||
|
||||
if existing_item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException("Attempted to change an item which is not owned by the user")
|
||||
|
||||
if existing_item.expired:
|
||||
raise InvokeException("Attempted to change an item which has already been played")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(beatmap.difficulty_rating)
|
||||
item.playlist_order = existing_item.playlist_order
|
||||
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
|
||||
# Update item in playlist
|
||||
for idx, playlist_item in enumerate(self.room.playlist):
|
||||
if playlist_item.id == item.id:
|
||||
self.room.playlist[idx] = item
|
||||
break
|
||||
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room,
|
||||
item,
|
||||
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
|
||||
)
|
||||
|
||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||
from app.database import Playlist
|
||||
|
||||
item = next(
|
||||
(i for i in self.room.playlist if i.id == playlist_item_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if item is None:
|
||||
raise InvokeException("Item does not exist in the room")
|
||||
|
||||
# Check if it's the only item and current item
|
||||
if item == self.current_item:
|
||||
upcoming_items = [i for i in self.room.playlist if not i.expired]
|
||||
if len(upcoming_items) == 1:
|
||||
raise InvokeException("The only item in the room cannot be removed")
|
||||
|
||||
if item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException("Attempted to remove an item which is not owned by the user")
|
||||
|
||||
if item.expired:
|
||||
raise InvokeException("Attempted to remove an item which has already been played")
|
||||
|
||||
async with with_db() as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
|
||||
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||
if found_item:
|
||||
self.room.playlist.remove(found_item)
|
||||
self.current_index = self.room.playlist.index(self.upcoming_items[0])
|
||||
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
await self.hub.playlist_removed(self.server_room, item.id)
|
||||
|
||||
async def finish_current_item(self):
|
||||
from app.database import Playlist
|
||||
|
||||
async with with_db() as session:
|
||||
played_at = utcnow()
|
||||
await session.execute(
|
||||
update(Playlist)
|
||||
.where(
|
||||
col(Playlist.id) == self.current_item.id,
|
||||
col(Playlist.room_id) == self.room.room_id,
|
||||
)
|
||||
.values(expired=True, played_at=played_at)
|
||||
)
|
||||
self.room.playlist[self.current_index].expired = True
|
||||
self.room.playlist[self.current_index].played_at = played_at
|
||||
await self.hub.playlist_changed(self.server_room, self.current_item, True)
|
||||
await self.update_order()
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_current_item()
|
||||
|
||||
async def update_queue_mode(self):
|
||||
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
|
||||
playitem.expired for playitem in self.room.playlist
|
||||
):
|
||||
assert self.room.host
|
||||
await self.add_item(self.current_item.clone(), self.room.host)
|
||||
await self.update_order()
|
||||
await self.update_current_item()
|
||||
|
||||
@property
|
||||
def current_item(self):
|
||||
return self.room.playlist[self.current_index]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CountdownInfo:
|
||||
countdown: MultiplayerCountdown
|
||||
duration: timedelta
|
||||
task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, countdown: MultiplayerCountdown):
|
||||
self.countdown = countdown
|
||||
self.duration = (
|
||||
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
class _MatchRequest(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class ChangeTeamRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
team_id: int
|
||||
|
||||
|
||||
class StartMatchCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
duration: timedelta
|
||||
|
||||
|
||||
class StopCountdownRequest(_MatchRequest):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
id: int
|
||||
|
||||
|
||||
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
|
||||
|
||||
|
||||
class MatchTypeHandler(ABC):
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
self.room = room
|
||||
self.hub = room.hub
|
||||
|
||||
@abstractmethod
|
||||
async def handle_join(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_details(self) -> MatchStartedEventDetail: ...
|
||||
|
||||
|
||||
class HeadToHeadHandler(MatchTypeHandler):
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
if user.match_state is not None:
|
||||
user.match_state = None
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
|
||||
return detail
|
||||
|
||||
|
||||
class TeamVersusHandler(MatchTypeHandler):
|
||||
@override
|
||||
def __init__(self, room: "ServerMultiplayerRoom"):
|
||||
super().__init__(room)
|
||||
self.state = TeamVersusRoomState()
|
||||
room.room.match_state = self.state
|
||||
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
|
||||
self.hub.tasks.add(task)
|
||||
task.add_done_callback(self.hub.tasks.discard)
|
||||
|
||||
def _get_best_available_team(self) -> int:
|
||||
for team in self.state.teams:
|
||||
if all(
|
||||
(
|
||||
user.match_state is None
|
||||
or not isinstance(user.match_state, TeamVersusUserState)
|
||||
or user.match_state.team_id != team.id
|
||||
)
|
||||
for user in self.room.room.users
|
||||
):
|
||||
return team.id
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
team_counts = defaultdict(int)
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
|
||||
team_counts[user.match_state.team_id] += 1
|
||||
|
||||
if team_counts:
|
||||
min_count = min(team_counts.values())
|
||||
for team_id, count in team_counts.items():
|
||||
if count == min_count:
|
||||
return team_id
|
||||
return self.state.teams[0].id if self.state.teams else 0
|
||||
|
||||
@override
|
||||
async def handle_join(self, user: MultiplayerRoomUser):
|
||||
best_team_id = self._get_best_available_team()
|
||||
user.match_state = TeamVersusUserState(team_id=best_team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
|
||||
if not isinstance(request, ChangeTeamRequest):
|
||||
return
|
||||
|
||||
if request.team_id not in [team.id for team in self.state.teams]:
|
||||
raise InvokeException("Invalid team ID")
|
||||
|
||||
user.match_state = TeamVersusUserState(team_id=request.team_id)
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@override
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
teams: dict[int, Literal["blue", "red"]] = {}
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
|
||||
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
|
||||
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
|
||||
return detail
|
||||
|
||||
|
||||
MATCH_TYPE_HANDLERS = {
|
||||
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
|
||||
MatchType.TEAM_VERSUS: TeamVersusHandler,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMultiplayerRoom:
|
||||
room: MultiplayerRoom
|
||||
category: RoomCategory
|
||||
status: RoomStatus
|
||||
start_at: datetime
|
||||
hub: "MultiplayerHub"
|
||||
match_type_handler: MatchTypeHandler
|
||||
queue: MultiplayerQueue
|
||||
_next_countdown_id: int
|
||||
_countdown_id_lock: asyncio.Lock
|
||||
_tracked_countdown: dict[int, CountdownInfo]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room: MultiplayerRoom,
|
||||
category: RoomCategory,
|
||||
start_at: datetime,
|
||||
hub: "MultiplayerHub",
|
||||
):
|
||||
self.room = room
|
||||
self.category = category
|
||||
self.status = RoomStatus.IDLE
|
||||
self.start_at = start_at
|
||||
self.hub = hub
|
||||
self.queue = MultiplayerQueue(self)
|
||||
self._next_countdown_id = 0
|
||||
self._countdown_id_lock = asyncio.Lock()
|
||||
self._tracked_countdown = {}
|
||||
|
||||
async def set_handler(self):
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
|
||||
for i in self.room.users:
|
||||
await self.match_type_handler.handle_join(i)
|
||||
|
||||
async def get_next_countdown_id(self) -> int:
|
||||
async with self._countdown_id_lock:
|
||||
self._next_countdown_id += 1
|
||||
return self._next_countdown_id
|
||||
|
||||
async def start_countdown(
|
||||
self,
|
||||
countdown: MultiplayerCountdown,
|
||||
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
|
||||
):
|
||||
async def _countdown_task(self: "ServerMultiplayerRoom"):
|
||||
await asyncio.sleep(info.duration.total_seconds())
|
||||
if on_complete is not None:
|
||||
await on_complete(self)
|
||||
await self.stop_countdown(countdown)
|
||||
|
||||
if countdown.is_exclusive:
|
||||
await self.stop_all_countdowns(countdown.__class__)
|
||||
countdown.id = await self.get_next_countdown_id()
|
||||
info = CountdownInfo(countdown)
|
||||
self.room.active_countdowns.append(info.countdown)
|
||||
self._tracked_countdown[countdown.id] = info
|
||||
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
|
||||
info.task = asyncio.create_task(_countdown_task(self))
|
||||
|
||||
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||
info = self._tracked_countdown.get(countdown.id)
|
||||
if info is None:
|
||||
return
|
||||
del self._tracked_countdown[countdown.id]
|
||||
self.room.active_countdowns.remove(countdown)
|
||||
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
|
||||
if info.task is not None and not info.task.done():
|
||||
info.task.cancel()
|
||||
|
||||
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
|
||||
for countdown in list(self._tracked_countdown.values()):
|
||||
if isinstance(countdown.countdown, typ):
|
||||
await self.stop_countdown(countdown.countdown)
|
||||
|
||||
|
||||
class _MatchServerEvent(SignalRUnionMessage): ...
|
||||
|
||||
|
||||
class CountdownStartedEvent(_MatchServerEvent):
|
||||
countdown: MultiplayerCountdown
|
||||
|
||||
union_type: ClassVar[Literal[0]] = 0
|
||||
|
||||
|
||||
class CountdownStoppedEvent(_MatchServerEvent):
|
||||
id: int
|
||||
|
||||
union_type: ClassVar[Literal[1]] = 1
|
||||
|
||||
|
||||
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
|
||||
|
||||
|
||||
class GameplayAbortReason(IntEnum):
|
||||
LOAD_TOOK_TOO_LONG = 0
|
||||
HOST_ABORTED = 1
|
||||
|
||||
|
||||
class MatchStartedEventDetail(TypedDict):
|
||||
room_type: Literal["playlists", "head_to_head", "team_versus"]
|
||||
team: dict[int, Literal["blue", "red"]] | None
|
||||
|
||||
22
app/models/playlist.py
Normal file
22
app/models/playlist.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PlaylistItem(BaseModel):
|
||||
id: int = Field(default=0, ge=-1)
|
||||
owner_id: int
|
||||
beatmap_id: int
|
||||
beatmap_checksum: str = ""
|
||||
ruleset_id: int = 0
|
||||
required_mods: list[APIMod] = Field(default_factory=list)
|
||||
allowed_mods: list[APIMod] = Field(default_factory=list)
|
||||
expired: bool = False
|
||||
playlist_order: int = 0
|
||||
played_at: datetime | None = None
|
||||
star_rating: float = 0.0
|
||||
freestyle: bool = False
|
||||
@@ -1,37 +1 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalRMeta:
|
||||
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
|
||||
json_ignore: bool = False # implement of JsonIgnore (json) attribute
|
||||
use_abbr: bool = True
|
||||
|
||||
|
||||
class SignalRUnionMessage(BaseModel):
|
||||
union_type: ClassVar[int]
|
||||
|
||||
|
||||
class Transport(BaseModel):
|
||||
transport: str
|
||||
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
|
||||
|
||||
|
||||
class NegotiateResponse(BaseModel):
|
||||
connectionId: str
|
||||
connectionToken: str
|
||||
negotiateVersion: int = 1
|
||||
availableTransports: list[Transport]
|
||||
|
||||
|
||||
class UserState(BaseModel):
|
||||
connection_id: str
|
||||
connection_token: str
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
|
||||
from .score import (
|
||||
ScoreStatistics,
|
||||
)
|
||||
from .signalr import SignalRMeta, UserState
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SpectatedUserState(IntEnum):
|
||||
Idle = 0
|
||||
Playing = 1
|
||||
Paused = 2
|
||||
Passed = 3
|
||||
Failed = 4
|
||||
Quit = 5
|
||||
|
||||
|
||||
class SpectatorState(BaseModel):
|
||||
beatmap_id: int | None = None
|
||||
ruleset_id: int | None = None # 0,1,2,3
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
state: SpectatedUserState
|
||||
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SpectatorState):
|
||||
return False
|
||||
return (
|
||||
self.beatmap_id == other.beatmap_id
|
||||
and self.ruleset_id == other.ruleset_id
|
||||
and self.mods == other.mods
|
||||
and self.state == other.state
|
||||
)
|
||||
|
||||
|
||||
class ScoreProcessorStatistics(BaseModel):
|
||||
base_score: float
|
||||
maximum_base_score: float
|
||||
accuracy_judgement_count: int
|
||||
combo_portion: float
|
||||
bonus_portion: float
|
||||
|
||||
|
||||
class FrameHeader(BaseModel):
|
||||
total_score: int
|
||||
accuracy: float
|
||||
combo: int
|
||||
max_combo: int
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
score_processor_statistics: ScoreProcessorStatistics
|
||||
received_time: datetime.datetime
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
|
||||
@field_validator("received_time", mode="before")
|
||||
@classmethod
|
||||
def validate_timestamp(cls, v: Any) -> datetime.datetime:
|
||||
if isinstance(v, list):
|
||||
return v[0]
|
||||
if isinstance(v, datetime.datetime):
|
||||
return v
|
||||
if isinstance(v, int | float):
|
||||
return datetime.datetime.fromtimestamp(v, tz=datetime.UTC)
|
||||
if isinstance(v, str):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
raise ValueError(f"Cannot convert {type(v)} to datetime")
|
||||
|
||||
|
||||
# class ReplayButtonState(IntEnum):
|
||||
# NONE = 0
|
||||
# LEFT1 = 1
|
||||
# RIGHT1 = 2
|
||||
# LEFT2 = 4
|
||||
# RIGHT2 = 8
|
||||
# SMOKE = 16
|
||||
|
||||
|
||||
class LegacyReplayFrame(BaseModel):
|
||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||
mouse_x: float | None = None
|
||||
mouse_y: float | None = None
|
||||
button_state: int
|
||||
|
||||
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
|
||||
|
||||
|
||||
class FrameDataBundle(BaseModel):
|
||||
header: FrameHeader
|
||||
frames: list[LegacyReplayFrame]
|
||||
|
||||
|
||||
# Use for server
|
||||
class APIUser(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ScoreInfo(BaseModel):
|
||||
mods: list[APIMod]
|
||||
user: APIUser
|
||||
ruleset: int
|
||||
maximum_statistics: ScoreStatistics
|
||||
id: int | None = None
|
||||
total_score: int | None = None
|
||||
accuracy: float | None = None
|
||||
max_combo: int | None = None
|
||||
combo: int | None = None
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StoreScore(BaseModel):
|
||||
score_info: ScoreInfo
|
||||
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StoreClientState(UserState):
|
||||
state: SpectatorState | None = None
|
||||
beatmap_status: BeatmapRankStatus | None = None
|
||||
checksum: str | None = None
|
||||
ruleset_id: int | None = None
|
||||
score_token: int | None = None
|
||||
watched_user: set[int] = Field(default_factory=set)
|
||||
score: StoreScore | None = None
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# from app.signalr import signalr_router as signalr_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .file import file_router as file_router
|
||||
@@ -25,5 +24,4 @@ __all__ = [
|
||||
"private_router",
|
||||
"redirect_api_router",
|
||||
"redirect_router",
|
||||
# "signalr_router",
|
||||
]
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.fetcher import Fetcher
|
||||
from app.dependencies.storage import StorageService
|
||||
from app.log import log
|
||||
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
|
||||
from app.models.playlist import PlaylistItem
|
||||
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
|
||||
from app.utils import utcnow
|
||||
|
||||
@@ -216,7 +216,7 @@ async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, A
|
||||
|
||||
# Insert playlist items
|
||||
for item_data in items_raw:
|
||||
hub_item = HubPlaylistItem(
|
||||
playlist_item = PlaylistItem(
|
||||
id=-1, # Placeholder, will be assigned by add_to_db
|
||||
owner_id=item_data["owner_id"],
|
||||
ruleset_id=item_data["ruleset_id"],
|
||||
@@ -230,7 +230,7 @@ async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, A
|
||||
beatmap_checksum=item_data["beatmap_checksum"],
|
||||
star_rating=item_data["star_rating"],
|
||||
)
|
||||
await DBPlaylist.add_to_db(hub_item, room_id=room_id, session=db)
|
||||
await DBPlaylist.add_to_db(playlist_item, room_id=room_id, session=db)
|
||||
|
||||
|
||||
async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int) -> None:
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
from math import ceil
|
||||
import random
|
||||
import shlex
|
||||
@@ -10,27 +9,15 @@ import shlex
|
||||
from app.calculator import calculate_weighted_pp
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
||||
from app.database.score import Score, get_best_id
|
||||
from app.database.statistics import UserStatistics, get_rank
|
||||
from app.database.user import User
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.exception import InvokeException
|
||||
from app.models.mods import APIMod, get_available_mods, mod_to_save
|
||||
from app.models.multiplayer_hub import (
|
||||
ChangeTeamRequest,
|
||||
ServerMultiplayerRoom,
|
||||
StartMatchCountdownRequest,
|
||||
)
|
||||
from app.models.room import MatchType, QueueMode, RoomStatus
|
||||
from app.models.mods import mod_to_save
|
||||
from app.models.score import GameMode
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
from app.signalr.hub.hub import Client
|
||||
|
||||
from .server import server
|
||||
|
||||
from httpx import HTTPError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -216,352 +203,6 @@ PP: {statistics.pp:.2f}
|
||||
"""
|
||||
|
||||
|
||||
async def _mp_name(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp name <name>"
|
||||
|
||||
name = args[0]
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.name = name
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return f"Room name has changed to {name}"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_set(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp set <teammode> [<queuemode>]"
|
||||
|
||||
teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0])
|
||||
if not teammode:
|
||||
return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus."
|
||||
queuemode = (
|
||||
{
|
||||
"0": QueueMode.HOST_ONLY,
|
||||
"1": QueueMode.ALL_PLAYERS,
|
||||
"2": QueueMode.ALL_PLAYERS_ROUND_ROBIN,
|
||||
}.get(args[1])
|
||||
if len(args) >= 2
|
||||
else None
|
||||
)
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.match_type = teammode
|
||||
if queuemode:
|
||||
settings.queue_mode = queuemode
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return f"Room setting 'teammode' has been changed to {teammode.name.lower()}"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_host(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp host <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.TransferHost(signalr_client, user_id)
|
||||
return f"User '{username}' has been hosted in the room."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_start(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
timer = None
|
||||
if len(args) >= 1 and args[0].isdigit():
|
||||
timer = int(args[0])
|
||||
|
||||
try:
|
||||
if timer is not None:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
signalr_client,
|
||||
StartMatchCountdownRequest(duration=timedelta(seconds=timer)),
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
await MultiplayerHubs.StartMatch(signalr_client)
|
||||
return "Good luck! Enjoy game!"
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_abort(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
try:
|
||||
await MultiplayerHubs.AbortMatch(signalr_client)
|
||||
return "Match aborted."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_team(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
):
|
||||
if room.room.settings.match_type != MatchType.TEAM_VERSUS:
|
||||
return "This command is only available in Team Versus mode."
|
||||
|
||||
if len(args) < 2:
|
||||
return "Usage: !mp team <username> <colour>"
|
||||
|
||||
username = args[0]
|
||||
team = {"red": 0, "blue": 1}.get(args[1])
|
||||
if team is None:
|
||||
return "Invalid team colour. Use 'red' or 'blue'."
|
||||
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
|
||||
if not user_client:
|
||||
return f"User '{username}' is not in the room."
|
||||
assert room.room.host
|
||||
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
|
||||
return "You are not allowed to change other users' teams."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_password(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
password = ""
|
||||
if len(args) >= 1:
|
||||
password = args[0]
|
||||
|
||||
try:
|
||||
settings = room.room.settings.model_copy()
|
||||
settings.password = password
|
||||
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
|
||||
return "Room password has been set."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_kick(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp kick <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.KickUser(signalr_client, user_id)
|
||||
return f"User '{username}' has been kicked from the room."
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_map(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp map <mapid> [<playmode>]"
|
||||
|
||||
if room.status != RoomStatus.IDLE:
|
||||
return "Cannot change map while the game is running."
|
||||
|
||||
map_id = args[0]
|
||||
if not map_id.isdigit():
|
||||
return "Invalid map ID."
|
||||
map_id = int(map_id)
|
||||
playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None
|
||||
if playmode not in (
|
||||
GameMode.OSU,
|
||||
GameMode.TAIKO,
|
||||
GameMode.FRUITS,
|
||||
GameMode.MANIA,
|
||||
None,
|
||||
):
|
||||
return "Invalid playmode."
|
||||
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
|
||||
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
|
||||
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
|
||||
except HTTPError:
|
||||
return "Beatmap not found"
|
||||
|
||||
try:
|
||||
current_item = room.queue.current_item
|
||||
item = current_item.model_copy(deep=True)
|
||||
item.owner_id = signalr_client.user_id
|
||||
item.beatmap_checksum = beatmap.checksum
|
||||
item.required_mods = []
|
||||
item.allowed_mods = []
|
||||
item.freestyle = False
|
||||
item.beatmap_id = map_id
|
||||
if playmode is not None:
|
||||
item.ruleset_id = int(playmode)
|
||||
if item.expired:
|
||||
item.id = 0
|
||||
item.expired = False
|
||||
item.played_at = None
|
||||
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
|
||||
else:
|
||||
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
async def _mp_mods(
|
||||
signalr_client: Client,
|
||||
room: ServerMultiplayerRoom,
|
||||
args: list[str],
|
||||
session: AsyncSession,
|
||||
) -> str:
|
||||
if len(args) < 1:
|
||||
return "Usage: !mp mods <mod1> [<mod2> ...]"
|
||||
|
||||
if room.status != RoomStatus.IDLE:
|
||||
return "Cannot change mods while the game is running."
|
||||
|
||||
required_mods = []
|
||||
allowed_mods = []
|
||||
freestyle = False
|
||||
freemod = False
|
||||
for arg in args:
|
||||
arg = arg.upper()
|
||||
if arg == "NONE":
|
||||
required_mods.clear()
|
||||
allowed_mods.clear()
|
||||
break
|
||||
elif arg == "FREESTYLE":
|
||||
freestyle = True
|
||||
elif arg == "FREEMOD":
|
||||
freemod = True
|
||||
elif arg.startswith("+"):
|
||||
mod = arg.removeprefix("+")
|
||||
if len(mod) != 2:
|
||||
return f"Invalid mod: {mod}."
|
||||
allowed_mods.append(APIMod(acronym=mod))
|
||||
else:
|
||||
if len(arg) != 2:
|
||||
return f"Invalid mod: {arg}."
|
||||
required_mods.append(APIMod(acronym=arg))
|
||||
|
||||
try:
|
||||
current_item = room.queue.current_item
|
||||
item = current_item.model_copy(deep=True)
|
||||
item.owner_id = signalr_client.user_id
|
||||
item.freestyle = freestyle
|
||||
if freestyle:
|
||||
item.allowed_mods = []
|
||||
elif freemod:
|
||||
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
|
||||
else:
|
||||
item.allowed_mods = allowed_mods
|
||||
item.required_mods = required_mods
|
||||
if item.expired:
|
||||
item.id = 0
|
||||
item.expired = False
|
||||
item.played_at = None
|
||||
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
|
||||
else:
|
||||
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
|
||||
|
||||
_MP_COMMANDS = {
|
||||
"name": _mp_name,
|
||||
"set": _mp_set,
|
||||
"host": _mp_host,
|
||||
"start": _mp_start,
|
||||
"abort": _mp_abort,
|
||||
"map": _mp_map,
|
||||
"mods": _mp_mods,
|
||||
"kick": _mp_kick,
|
||||
"password": _mp_password,
|
||||
"team": _mp_team,
|
||||
}
|
||||
_MP_HELP = """!mp name <name>
|
||||
!mp set <teammode> [<queuemode>]
|
||||
!mp host <host>
|
||||
!mp start [<timer>]
|
||||
!mp abort
|
||||
!mp map <map> [<playmode>]
|
||||
!mp mods <mod1> [<mod2> ...]
|
||||
!mp kick <user>
|
||||
!mp password [<password>]
|
||||
!mp team <user> <team:red|blue>"""
|
||||
|
||||
|
||||
@bot.command("mp")
|
||||
async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
|
||||
if not channel.name.startswith("room_"):
|
||||
return
|
||||
|
||||
room_id = int(channel.name[5:])
|
||||
room = MultiplayerHubs.rooms.get(room_id)
|
||||
if not room:
|
||||
return
|
||||
signalr_client = MultiplayerHubs.get_client_by_id(str(user.id))
|
||||
if not signalr_client:
|
||||
return
|
||||
|
||||
if len(args) < 1:
|
||||
return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]"
|
||||
|
||||
command = args[0].lower()
|
||||
if command not in _MP_COMMANDS:
|
||||
return f"No such command: {command}"
|
||||
|
||||
return await _MP_COMMANDS[command](signalr_client, room, args[1:], session)
|
||||
|
||||
|
||||
async def _score(
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
|
||||
@@ -16,7 +16,6 @@ from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.user import ClientUser, get_current_user
|
||||
from app.models.room import RoomCategory, RoomStatus
|
||||
from app.service.room import create_playlist_room_from_api
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
from app.utils import utcnow
|
||||
|
||||
from .router import router
|
||||
@@ -391,14 +390,12 @@ async def get_room_events(
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
if room := MultiplayerHubs.rooms.get(room_id):
|
||||
current_playlist_item_id = room.queue.current_item.id
|
||||
room_resp = await RoomResp.from_hub(room)
|
||||
else:
|
||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
room_resp = await RoomResp.from_db(room, db)
|
||||
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
room_resp = await RoomResp.from_db(room, db)
|
||||
if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
|
||||
current_playlist_item_id = room_resp.current_playlist_item.id
|
||||
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
|
||||
@@ -217,8 +217,7 @@ class MessageQueueProcessor:
|
||||
):
|
||||
"""通知客户端消息ID已更新"""
|
||||
try:
|
||||
# 这里我们需要通过 SignalR 发送消息更新通知
|
||||
# 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件
|
||||
# 通过 Redis 发布消息更新事件,由聊天通知服务分发到客户端
|
||||
update_event = {
|
||||
"event": "chat.message.update",
|
||||
"data": {
|
||||
@@ -229,7 +228,6 @@ class MessageQueueProcessor:
|
||||
},
|
||||
}
|
||||
|
||||
# 发布到 Redis 频道,让 SignalR 服务处理
|
||||
await self._redis_exec(
|
||||
self.redis_message.publish,
|
||||
f"chat_updates:{channel_id}",
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.database import PlaylistBestScore, Score
|
||||
from app.database.playlist_best_score import get_position
|
||||
from app.dependencies.database import with_db
|
||||
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
|
||||
|
||||
from .base import RedisSubscriber
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.signalr.hub import MetadataHub
|
||||
|
||||
|
||||
CHANNEL = "osu-channel:score:processed"
|
||||
|
||||
|
||||
class ScoreSubscriber(RedisSubscriber):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.room_subscriber: dict[int, list[int]] = {}
|
||||
self.metadata_hub: "MetadataHub | None " = None
|
||||
self.subscribed = False
|
||||
self.handlers[CHANNEL] = [self._handler]
|
||||
|
||||
async def subscribe_room_score(self, room_id: int, user_id: int):
|
||||
if room_id not in self.room_subscriber:
|
||||
await self.subscribe(CHANNEL)
|
||||
self.start()
|
||||
self.room_subscriber.setdefault(room_id, []).append(user_id)
|
||||
|
||||
async def unsubscribe_room_score(self, room_id: int, user_id: int):
|
||||
if room_id in self.room_subscriber:
|
||||
try:
|
||||
self.room_subscriber[room_id].remove(user_id)
|
||||
except ValueError:
|
||||
pass
|
||||
if not self.room_subscriber[room_id]:
|
||||
del self.room_subscriber[room_id]
|
||||
|
||||
async def _notify_room_score_processed(self, score_id: int):
|
||||
if not self.metadata_hub:
|
||||
return
|
||||
async with with_db() as session:
|
||||
score = await session.get(Score, score_id)
|
||||
if not score or not score.passed or score.room_id is None or score.playlist_item_id is None:
|
||||
return
|
||||
if not self.room_subscriber.get(score.room_id, []):
|
||||
return
|
||||
|
||||
new_rank = None
|
||||
user_best = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.user_id == score.user_id,
|
||||
PlaylistBestScore.room_id == score.room_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if user_best and user_best.score_id == score_id:
|
||||
new_rank = await get_position(
|
||||
user_best.room_id,
|
||||
user_best.playlist_id,
|
||||
user_best.score_id,
|
||||
session,
|
||||
)
|
||||
|
||||
event = MultiplayerRoomScoreSetEvent(
|
||||
room_id=score.room_id,
|
||||
playlist_item_id=score.playlist_item_id,
|
||||
score_id=score_id,
|
||||
user_id=score.user_id,
|
||||
total_score=score.total_score,
|
||||
new_rank=new_rank,
|
||||
)
|
||||
await self.metadata_hub.notify_room_score_processed(event)
|
||||
|
||||
async def _handler(self, channel: str, data: str):
|
||||
score_id = json.loads(data)["ScoreId"]
|
||||
if self.metadata_hub:
|
||||
await self._notify_room_score_processed(score_id)
|
||||
@@ -1,5 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .router import router as signalr_router
|
||||
|
||||
__all__ = ["signalr_router"]
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
from .metadata import MetadataHub
|
||||
from .multiplayer import MultiplayerHub
|
||||
from .spectator import SpectatorHub
|
||||
|
||||
SpectatorHubs = SpectatorHub()
|
||||
MultiplayerHubs = MultiplayerHub()
|
||||
MetadataHubs = MetadataHub()
|
||||
Hubs: dict[str, Hub] = {
|
||||
"spectator": SpectatorHubs,
|
||||
"multiplayer": MultiplayerHubs,
|
||||
"metadata": MetadataHubs,
|
||||
}
|
||||
@@ -1,322 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
InvocationPacket,
|
||||
Packet,
|
||||
PingPacket,
|
||||
Protocol,
|
||||
)
|
||||
from app.signalr.store import ResultStore
|
||||
from app.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class CloseConnection(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Connection closed",
|
||||
allow_reconnect: bool = False,
|
||||
from_client: bool = False,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.allow_reconnect = allow_reconnect
|
||||
self.from_client = from_client
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
connection: WebSocket,
|
||||
protocol: Protocol,
|
||||
) -> None:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self.protocol = protocol
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.connection_token)
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
return int(self.connection_id)
|
||||
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.protocol.encode(packet))
|
||||
|
||||
async def receive_packets(self) -> list[Packet]:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return []
|
||||
return self.protocol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PingPacket())
|
||||
await asyncio.sleep(settings.signalr_ping_interval)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e) or "close message" in str(e):
|
||||
break
|
||||
else:
|
||||
logger.error(f"Error in ping task for {self.connection_id}: {e}")
|
||||
break
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {self.connection_id}")
|
||||
|
||||
|
||||
class Hub[TState: UserState]:
|
||||
def __init__(self) -> None:
|
||||
self.clients: dict[str, Client] = {}
|
||||
self.waited_clients: dict[str, int] = {}
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.groups: dict[str, set[Client]] = {}
|
||||
self.state: dict[int, TState] = {}
|
||||
|
||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
|
||||
def get_client_by_id(self, id: str, default: Any = None) -> Client:
|
||||
for client in self.clients.values():
|
||||
if client.connection_id == id:
|
||||
return client
|
||||
return default
|
||||
|
||||
def get_before_clients(self, id: str, current_token: str) -> list[Client]:
|
||||
clients = []
|
||||
for client in self.clients.values():
|
||||
if client.connection_id != id:
|
||||
continue
|
||||
if client.connection_token == current_token:
|
||||
continue
|
||||
clients.append(client)
|
||||
return clients
|
||||
|
||||
@abstractmethod
|
||||
def create_state(self, client: Client) -> TState:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_or_create_state(self, client: Client) -> TState:
|
||||
if (state := self.state.get(client.user_id)) is not None:
|
||||
return state
|
||||
state = self.create_state(client)
|
||||
self.state[client.user_id] = state
|
||||
return state
|
||||
|
||||
def add_to_group(self, client: Client, group_id: str) -> None:
|
||||
self.groups.setdefault(group_id, set()).add(client)
|
||||
|
||||
def remove_from_group(self, client: Client, group_id: str) -> None:
|
||||
if group_id in self.groups:
|
||||
self.groups[group_id].discard(client)
|
||||
|
||||
async def kick_client(self, client: Client) -> None:
|
||||
await self.call_noblock(client, "DisconnectRequested")
|
||||
await client.send_packet(ClosePacket(allow_reconnect=False))
|
||||
await client.connection.close(code=1000, reason="Disconnected by server")
|
||||
|
||||
async def add_client(
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
protocol: Protocol,
|
||||
connection: WebSocket,
|
||||
) -> Client:
|
||||
if connection_token in self.clients:
|
||||
raise ValueError(f"Client with connection token {connection_token} already exists.")
|
||||
if connection_token in self.waited_clients:
|
||||
if self.waited_clients[connection_token] < time.time() - settings.signalr_negotiate_timeout:
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
client = Client(connection_id, connection_token, connection, protocol)
|
||||
self.clients[connection_token] = client
|
||||
task = asyncio.create_task(client._ping())
|
||||
self.tasks.add(task)
|
||||
client._ping_task = task
|
||||
return client
|
||||
|
||||
async def remove_client(self, client: Client) -> None:
|
||||
if client.connection_token not in self.clients:
|
||||
return
|
||||
del self.clients[client.connection_token]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
for group in self.groups.values():
|
||||
group.discard(client)
|
||||
await self.clean_state(client, False)
|
||||
|
||||
@abstractmethod
|
||||
async def _clean_state(self, state: TState) -> None:
|
||||
return
|
||||
|
||||
async def clean_state(self, client: Client, disconnected: bool) -> None:
|
||||
if (state := self.state.get(client.user_id)) is None:
|
||||
return
|
||||
if disconnected and client.connection_token != state.connection_token:
|
||||
return
|
||||
try:
|
||||
await self._clean_state(state)
|
||||
del self.state[client.user_id]
|
||||
except Exception:
|
||||
...
|
||||
|
||||
async def on_connect(self, client: Client) -> None:
|
||||
if method := getattr(self, "on_client_connect", None):
|
||||
await method(client)
|
||||
|
||||
async def send_packet(self, client: Client, packet: Packet) -> None:
|
||||
logger.trace(f"[SignalR] send to {client.connection_id} packet {packet}")
|
||||
try:
|
||||
await client.send_packet(packet)
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
|
||||
await self.remove_client(client)
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"Client {client.connection_id} closed the connection.")
|
||||
else:
|
||||
logger.exception(f"RuntimeError in client {client.connection_id}: {e}")
|
||||
await self.remove_client(client)
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {client.connection_id}")
|
||||
await self.remove_client(client)
|
||||
|
||||
async def broadcast_call(self, method: str, *args: Any) -> None:
|
||||
tasks = []
|
||||
for client in self.clients.values():
|
||||
tasks.append(self.call_noblock(client, method, *args))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def broadcast_group_call(self, group_id: str, method: str, *args: Any) -> None:
|
||||
tasks = []
|
||||
for client in self.groups.get(group_id, []):
|
||||
tasks.append(self.call_noblock(client, method, *args))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _listen_client(self, client: Client) -> None:
|
||||
try:
|
||||
while True:
|
||||
packets = await client.receive_packets()
|
||||
for packet in packets:
|
||||
if isinstance(packet, PingPacket):
|
||||
continue
|
||||
elif isinstance(packet, ClosePacket):
|
||||
raise CloseConnection(
|
||||
packet.error or "Connection closed by client",
|
||||
packet.allow_reconnect,
|
||||
True,
|
||||
)
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"Client {client.connection_id} closed the connection.")
|
||||
else:
|
||||
logger.exception(f"RuntimeError in client {client.connection_id}: {e}")
|
||||
except CloseConnection as e:
|
||||
if not e.from_client:
|
||||
await client.send_packet(ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect))
|
||||
logger.info(f"Client {client.connection_id} closed the connection: {e.message}")
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {client.connection_id}")
|
||||
|
||||
await self.remove_client(client)
|
||||
|
||||
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
||||
if isinstance(packet, PingPacket):
|
||||
return
|
||||
elif isinstance(packet, InvocationPacket):
|
||||
args = packet.arguments or []
|
||||
error = None
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, packet.target, args)
|
||||
except InvokeException as e:
|
||||
error = e.message
|
||||
logger.debug(f"Client {client.connection_token} call {packet.target} failed: {error}")
|
||||
except Exception:
|
||||
logger.exception(f"Error invoking method {packet.target} for client {client.connection_id}")
|
||||
error = "Unknown error occured in server"
|
||||
if packet.invocation_id is not None:
|
||||
await self.send_packet(
|
||||
client,
|
||||
CompletionPacket(
|
||||
invocation_id=packet.invocation_id,
|
||||
error=error,
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
client._store.add_result(packet.invocation_id, packet.result, packet.error)
|
||||
|
||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||
method_ = getattr(self, method, None)
|
||||
call_params = []
|
||||
if not method_:
|
||||
raise InvokeException(f"Method '{method}' not found in hub.")
|
||||
signature = get_signature(method_)
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
call_params.append(client.protocol.validate_object(args.pop(0), param.annotation))
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
invocation_id = client._store.get_invocation_id()
|
||||
await self.send_packet(
|
||||
client,
|
||||
InvocationPacket(
|
||||
header={},
|
||||
invocation_id=invocation_id,
|
||||
target=method,
|
||||
arguments=list(args),
|
||||
stream_ids=None,
|
||||
),
|
||||
)
|
||||
r = await client._store.fetch(invocation_id, None)
|
||||
if r[1]:
|
||||
raise InvokeException(r[1])
|
||||
return r[0]
|
||||
|
||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||
await self.send_packet(
|
||||
client,
|
||||
InvocationPacket(
|
||||
header={},
|
||||
invocation_id=None,
|
||||
target=method,
|
||||
arguments=list(args),
|
||||
stream_ids=None,
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self.clients or item in self.waited_clients
|
||||
@@ -1,296 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Coroutine
|
||||
import math
|
||||
from typing import override
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.database import Relationship, RelationshipType, User
|
||||
from app.database.playlist_best_score import PlaylistBestScore
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import Room
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import with_db
|
||||
from app.log import logger
|
||||
from app.models.metadata_hub import (
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
DailyChallengeInfo,
|
||||
MetadataClientState,
|
||||
MultiplayerPlaylistItemStats,
|
||||
MultiplayerRoomScoreSetEvent,
|
||||
MultiplayerRoomStats,
|
||||
OnlineStatus,
|
||||
UserActivity,
|
||||
)
|
||||
from app.models.room import RoomCategory
|
||||
from app.service.subscribers.score_processed import ScoreSubscriber
|
||||
from app.utils import utcnow
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
|
||||
|
||||
class MetadataHub(Hub[MetadataClientState]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.subscriber = ScoreSubscriber()
|
||||
self.subscriber.metadata_hub = self
|
||||
self._daily_challenge_stats: MultiplayerRoomStats | None = None
|
||||
self._today = utcnow().date()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def get_daily_challenge_stats(self, daily_challenge_room: int) -> MultiplayerRoomStats:
|
||||
if self._daily_challenge_stats is None or self._today != utcnow().date():
|
||||
self._daily_challenge_stats = MultiplayerRoomStats(
|
||||
room_id=daily_challenge_room,
|
||||
playlist_item_stats={},
|
||||
)
|
||||
return self._daily_challenge_stats
|
||||
|
||||
@staticmethod
|
||||
def online_presence_watchers_group() -> str:
|
||||
return ONLINE_PRESENCE_WATCHERS_GROUP
|
||||
|
||||
@staticmethod
|
||||
def room_watcher_group(room_id: int) -> str:
|
||||
return f"metadata:multiplayer-room-watchers:{room_id}"
|
||||
|
||||
def broadcast_tasks(self, user_id: int, store: MetadataClientState | None) -> set[Coroutine]:
|
||||
if store is not None and not store.pushable:
|
||||
return set()
|
||||
data = store.for_push if store else None
|
||||
return {
|
||||
self.broadcast_group_call(
|
||||
self.online_presence_watchers_group(),
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
data,
|
||||
),
|
||||
self.broadcast_group_call(
|
||||
self.friend_presence_watchers_group(user_id),
|
||||
"FriendPresenceUpdated",
|
||||
user_id,
|
||||
data,
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def friend_presence_watchers_group(user_id: int):
|
||||
return f"metadata:friend-presence-watchers:{user_id}"
|
||||
|
||||
@override
|
||||
async def _clean_state(self, state: MetadataClientState) -> None:
|
||||
user_id = int(state.connection_id)
|
||||
|
||||
if state.pushable:
|
||||
await asyncio.gather(*self.broadcast_tasks(user_id, None))
|
||||
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
user = (await session.exec(select(User).where(User.id == int(state.connection_id)))).one()
|
||||
user.last_visit = utcnow()
|
||||
await session.commit()
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> MetadataClientState:
|
||||
return MetadataClientState(
|
||||
connection_id=client.connection_id,
|
||||
connection_token=client.connection_token,
|
||||
)
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
|
||||
# CRITICAL FIX: Set online status IMMEDIATELY upon connection
|
||||
# This matches the C# official implementation behavior
|
||||
store.status = OnlineStatus.ONLINE
|
||||
logger.info(f"[MetadataHub] Set user {user_id} status to ONLINE upon connection")
|
||||
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
friends = (
|
||||
await session.exec(
|
||||
select(Relationship.target_id).where(
|
||||
Relationship.user_id == user_id,
|
||||
Relationship.type == RelationshipType.FOLLOW,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
tasks = []
|
||||
for friend_id in friends:
|
||||
self.groups.setdefault(self.friend_presence_watchers_group(friend_id), set()).add(client)
|
||||
if (friend_state := self.state.get(friend_id)) and friend_state.pushable:
|
||||
tasks.append(
|
||||
self.broadcast_group_call(
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
"FriendPresenceUpdated",
|
||||
friend_id,
|
||||
friend_state.for_push if friend_state.pushable else None,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
daily_challenge_room = (
|
||||
await session.exec(
|
||||
select(Room).where(
|
||||
col(Room.ends_at) > utcnow(),
|
||||
Room.category == RoomCategory.DAILY_CHALLENGE,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if daily_challenge_room:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"DailyChallengeUpdated",
|
||||
DailyChallengeInfo(
|
||||
room_id=daily_challenge_room.id,
|
||||
),
|
||||
)
|
||||
|
||||
# CRITICAL FIX: Immediately broadcast the user's online status to all watchers
|
||||
# This ensures the user appears as "currently online" right after connection
|
||||
# Similar to the C# implementation's immediate broadcast logic
|
||||
online_presence_tasks = self.broadcast_tasks(user_id, store)
|
||||
if online_presence_tasks:
|
||||
await asyncio.gather(*online_presence_tasks)
|
||||
logger.info(f"[MetadataHub] Broadcasted online status for user {user_id} to watchers")
|
||||
|
||||
# Also send the user's own presence update to confirm online status
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.for_push,
|
||||
)
|
||||
logger.info(f"[MetadataHub] User {user_id} is now ONLINE and visible to other clients")
|
||||
|
||||
async def UpdateStatus(self, client: Client, status: int) -> None:
|
||||
status_ = OnlineStatus(status)
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
if store.status is not None and store.status == status_:
|
||||
return
|
||||
store.status = OnlineStatus(status_)
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.for_push,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateActivity(self, client: Client, activity: UserActivity | None) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
store.activity = activity
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.for_push,
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def BeginWatchingUserPresence(self, client: Client) -> None:
|
||||
# Critical fix: Send all currently online users to the new watcher
|
||||
# Must use for_push to get the correct UserPresence format
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.for_push, # Fixed: use for_push instead of store
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.pushable
|
||||
]
|
||||
)
|
||||
self.add_to_group(client, self.online_presence_watchers_group())
|
||||
logger.info(
|
||||
f"[MetadataHub] Client {client.connection_id} now watching user presence, "
|
||||
f"sent {len([s for s in self.state.values() if s.pushable])} online users"
|
||||
)
|
||||
|
||||
async def EndWatchingUserPresence(self, client: Client) -> None:
|
||||
self.remove_from_group(client, self.online_presence_watchers_group())
|
||||
|
||||
async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent):
|
||||
await self.broadcast_group_call(self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event)
|
||||
|
||||
async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int):
|
||||
self.add_to_group(client, self.room_watcher_group(room_id))
|
||||
await self.subscriber.subscribe_room_score(room_id, client.user_id)
|
||||
stats = self.get_daily_challenge_stats(room_id)
|
||||
await self.update_daily_challenge_stats(stats)
|
||||
return list(stats.playlist_item_stats.values())
|
||||
|
||||
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
|
||||
async with with_db() as session:
|
||||
playlist_ids = (
|
||||
await session.exec(
|
||||
select(Playlist.id).where(
|
||||
Playlist.room_id == stats.room_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
for playlist_id in playlist_ids:
|
||||
item = stats.playlist_item_stats.get(playlist_id, None)
|
||||
if item is None:
|
||||
item = MultiplayerPlaylistItemStats(
|
||||
playlist_item_id=playlist_id,
|
||||
total_score_distribution=[0] * TOTAL_SCORE_DISTRIBUTION_BINS,
|
||||
cumulative_score=0,
|
||||
last_processed_score_id=0,
|
||||
)
|
||||
stats.playlist_item_stats[playlist_id] = item
|
||||
last_processed_score_id = item.last_processed_score_id
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == stats.room_id,
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.score_id > last_processed_score_id,
|
||||
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
if len(scores) == 0:
|
||||
continue
|
||||
|
||||
async with self._lock:
|
||||
if item.last_processed_score_id == last_processed_score_id:
|
||||
totals = defaultdict(int)
|
||||
for score in scores:
|
||||
bin_index = int(
|
||||
clamp(
|
||||
math.floor(score.total_score / 100000),
|
||||
0,
|
||||
TOTAL_SCORE_DISTRIBUTION_BINS - 1,
|
||||
)
|
||||
)
|
||||
totals[bin_index] += 1
|
||||
|
||||
item.cumulative_score += sum(score.total_score for score in scores)
|
||||
|
||||
for j in range(TOTAL_SCORE_DISTRIBUTION_BINS):
|
||||
item.total_score_distribution[j] += totals.get(j, 0)
|
||||
|
||||
if scores:
|
||||
item.last_processed_score_id = max(score.score_id for score in scores)
|
||||
|
||||
async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int):
|
||||
self.remove_from_group(client, self.room_watcher_group(room_id))
|
||||
await self.subscriber.unsubscribe_room_score(room_id, client.user_id)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,585 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import lzma
|
||||
import struct
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from app.calculator import clamp
|
||||
from app.config import settings
|
||||
from app.database import Beatmap, User
|
||||
from app.database.failtime import FailTime, FailTimeResp
|
||||
from app.database.score import Score
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.storage import get_storage_service
|
||||
from app.exception import InvokeException
|
||||
from app.log import logger
|
||||
from app.models.mods import APIMod, mods_to_int
|
||||
from app.models.score import GameMode, LegacyReplaySoloScoreInfo, ScoreStatistics
|
||||
from app.models.spectator_hub import (
|
||||
APIUser,
|
||||
FrameDataBundle,
|
||||
LegacyReplayFrame,
|
||||
ScoreInfo,
|
||||
SpectatedUserState,
|
||||
SpectatorState,
|
||||
StoreClientState,
|
||||
StoreScore,
|
||||
)
|
||||
from app.utils import unix_timestamp_to_windows
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from httpx import HTTPError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import select
|
||||
|
||||
READ_SCORE_TIMEOUT = 30
|
||||
REPLAY_LATEST_VER = 30000016
|
||||
|
||||
|
||||
def encode_uleb128(num: int) -> bytes | bytearray:
|
||||
if num == 0:
|
||||
return b"\x00"
|
||||
|
||||
ret = bytearray()
|
||||
|
||||
while num != 0:
|
||||
ret.append(num & 0x7F)
|
||||
num >>= 7
|
||||
if num != 0:
|
||||
ret[-1] |= 0x80
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def encode_string(s: str) -> bytes:
|
||||
"""Write `s` into bytes (ULEB128 & string)."""
|
||||
if s:
|
||||
encoded = s.encode()
|
||||
ret = b"\x0b" + encode_uleb128(len(encoded)) + encoded
|
||||
else:
|
||||
ret = b"\x00"
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def save_replay(
|
||||
ruleset_id: int,
|
||||
md5: str,
|
||||
username: str,
|
||||
score: Score,
|
||||
statistics: ScoreStatistics,
|
||||
maximum_statistics: ScoreStatistics,
|
||||
frames: list[LegacyReplayFrame],
|
||||
) -> None:
|
||||
data = bytearray()
|
||||
data.extend(struct.pack("<bi", ruleset_id, REPLAY_LATEST_VER))
|
||||
data.extend(encode_string(md5))
|
||||
data.extend(encode_string(username))
|
||||
data.extend(encode_string(f"lazer-{username}-{score.started_at.isoformat()}"))
|
||||
data.extend(
|
||||
struct.pack(
|
||||
"<hhhhhhihbi",
|
||||
score.n300,
|
||||
score.n100,
|
||||
score.n50,
|
||||
score.ngeki,
|
||||
score.nkatu,
|
||||
score.nmiss,
|
||||
score.total_score,
|
||||
score.max_combo,
|
||||
score.is_perfect_combo,
|
||||
mods_to_int(score.mods),
|
||||
)
|
||||
)
|
||||
data.extend(encode_string("")) # hp graph
|
||||
data.extend(
|
||||
struct.pack(
|
||||
"<q",
|
||||
unix_timestamp_to_windows(round(score.started_at.timestamp())),
|
||||
)
|
||||
)
|
||||
|
||||
# write frames
|
||||
frame_strs = []
|
||||
last_time = 0
|
||||
for frame in frames:
|
||||
time = round(frame.time)
|
||||
frame_strs.append(f"{time - last_time}|{frame.mouse_x or 0.0}|{frame.mouse_y or 0.0}|{frame.button_state}")
|
||||
last_time = time
|
||||
frame_strs.append("-12345|0|0|0")
|
||||
|
||||
compressed = lzma.compress(",".join(frame_strs).encode("ascii"), format=lzma.FORMAT_ALONE)
|
||||
data.extend(struct.pack("<i", len(compressed)))
|
||||
data.extend(compressed)
|
||||
data.extend(struct.pack("<q", score.id))
|
||||
score_info = LegacyReplaySoloScoreInfo(
|
||||
online_id=score.id,
|
||||
mods=score.mods,
|
||||
statistics=statistics,
|
||||
maximum_statistics=maximum_statistics,
|
||||
client_version="",
|
||||
rank=score.rank,
|
||||
user_id=score.user_id,
|
||||
total_score_without_mods=score.total_score_without_mods,
|
||||
)
|
||||
compressed = lzma.compress(json.dumps(score_info).encode(), format=lzma.FORMAT_ALONE)
|
||||
data.extend(struct.pack("<i", len(compressed)))
|
||||
data.extend(compressed)
|
||||
|
||||
storage_service = get_storage_service()
|
||||
replay_path = score.replay_filename
|
||||
await storage_service.write_file(replay_path, bytes(data), "application/x-osu-replay")
|
||||
|
||||
|
||||
class SpectatorHub(Hub[StoreClientState]):
|
||||
@staticmethod
|
||||
def group_id(user_id: int) -> str:
|
||||
return f"watch:{user_id}"
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> StoreClientState:
|
||||
return StoreClientState(
|
||||
connection_id=client.connection_id,
|
||||
connection_token=client.connection_token,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _clean_state(self, state: StoreClientState) -> None:
|
||||
"""
|
||||
Enhanced cleanup based on official osu-server-spectator implementation.
|
||||
Properly notifies watched users when spectator disconnects.
|
||||
"""
|
||||
user_id = int(state.connection_id)
|
||||
if state.state:
|
||||
await self._end_session(user_id, state.state, state)
|
||||
|
||||
# Critical fix: Notify all watched users that this spectator has disconnected
|
||||
# This matches the official CleanUpState implementation
|
||||
for watched_user_id in state.watched_user:
|
||||
if (target_client := self.get_client_by_id(str(watched_user_id))) is not None:
|
||||
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||
logger.debug(f"[SpectatorHub] Notified {watched_user_id} that {user_id} stopped watching")
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
"""
|
||||
Enhanced connection handling based on official implementation.
|
||||
Send all active player states to newly connected clients.
|
||||
"""
|
||||
logger.info(f"[SpectatorHub] Client {client.user_id} connected")
|
||||
|
||||
# Send all current player states to the new client
|
||||
# This matches the official OnConnectedAsync behavior
|
||||
active_states = []
|
||||
for user_id, store in self.state.items():
|
||||
if store.state is not None:
|
||||
active_states.append((user_id, store.state))
|
||||
|
||||
if active_states:
|
||||
logger.debug(f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}")
|
||||
# Send states sequentially to avoid overwhelming the client
|
||||
for user_id, state in active_states:
|
||||
try:
|
||||
await self.call_noblock(client, "UserBeganPlaying", user_id, state)
|
||||
except Exception as e:
|
||||
logger.debug(f"[SpectatorHub] Failed to send state for user {user_id}: {e}")
|
||||
|
||||
# Also sync with MultiplayerHub for cross-hub spectating
|
||||
await self._sync_with_multiplayer_hub(client)
|
||||
|
||||
async def _sync_with_multiplayer_hub(self, client: Client) -> None:
|
||||
"""
|
||||
Sync with MultiplayerHub to get active multiplayer game states.
|
||||
This ensures spectators can see multiplayer games from other pages.
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from app.signalr.hub import MultiplayerHubs
|
||||
|
||||
# Check all active multiplayer rooms for playing users
|
||||
for room_id, server_room in MultiplayerHubs.rooms.items():
|
||||
for room_user in server_room.room.users:
|
||||
# Send state for users who are playing or in results
|
||||
if room_user.state.is_playing and room_user.user_id not in self.state:
|
||||
# Create a synthetic SpectatorState for multiplayer players
|
||||
# This helps with cross-hub spectating
|
||||
try:
|
||||
synthetic_state = SpectatorState(
|
||||
beatmap_id=server_room.queue.current_item.beatmap_id,
|
||||
ruleset_id=room_user.ruleset_id or 0, # Default to osu!
|
||||
mods=room_user.mods,
|
||||
state=SpectatedUserState.Playing,
|
||||
maximum_statistics={},
|
||||
)
|
||||
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
room_user.user_id,
|
||||
synthetic_state,
|
||||
)
|
||||
logger.debug(
|
||||
f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[SpectatorHub] Failed to create synthetic state: {e}")
|
||||
|
||||
# Critical addition: Notify about finished players in multiplayer games
|
||||
elif (
|
||||
hasattr(room_user.state, "name")
|
||||
and room_user.state.name == "RESULTS"
|
||||
and room_user.user_id not in self.state
|
||||
):
|
||||
try:
|
||||
# Create a synthetic finished state
|
||||
finished_state = SpectatorState(
|
||||
beatmap_id=server_room.queue.current_item.beatmap_id,
|
||||
ruleset_id=room_user.ruleset_id or 0,
|
||||
mods=room_user.mods,
|
||||
state=SpectatedUserState.Passed, # Assume passed for results
|
||||
maximum_statistics={},
|
||||
)
|
||||
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserFinishedPlaying",
|
||||
room_user.user_id,
|
||||
finished_state,
|
||||
)
|
||||
logger.debug(f"[SpectatorHub] Sent synthetic finished state for user {room_user.user_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[SpectatorHub] Failed to create synthetic finished state: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}")
|
||||
# This is not critical, so we don't raise the exception
|
||||
|
||||
async def BeginPlaySession(self, client: Client, score_token: int, state: SpectatorState) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
if store.state is not None:
|
||||
logger.warning(f"[SpectatorHub] User {user_id} began new session without ending previous one; cleaning up")
|
||||
try:
|
||||
await self._end_session(user_id, store.state, store)
|
||||
finally:
|
||||
store.state = None
|
||||
store.beatmap_status = None
|
||||
store.checksum = None
|
||||
store.ruleset_id = None
|
||||
store.score_token = None
|
||||
store.score = None
|
||||
if state.beatmap_id is None or state.ruleset_id is None:
|
||||
return
|
||||
|
||||
fetcher = await get_fetcher()
|
||||
async with with_db() as session:
|
||||
async with session.begin():
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=state.beatmap_id)
|
||||
except HTTPError:
|
||||
raise InvokeException(f"Beatmap {state.beatmap_id} not found.")
|
||||
user = (await session.exec(select(User).where(User.id == user_id))).first()
|
||||
if not user:
|
||||
return
|
||||
name = user.username
|
||||
store.state = state
|
||||
store.beatmap_status = beatmap.beatmap_status
|
||||
store.checksum = beatmap.checksum
|
||||
store.ruleset_id = state.ruleset_id
|
||||
store.score_token = score_token
|
||||
store.score = StoreScore(
|
||||
score_info=ScoreInfo(
|
||||
mods=state.mods,
|
||||
user=APIUser(id=user_id, name=name),
|
||||
ruleset=state.ruleset_id,
|
||||
maximum_statistics=state.maximum_statistics,
|
||||
)
|
||||
)
|
||||
logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}")
|
||||
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
user_id,
|
||||
state,
|
||||
)
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
if store.state is None or store.score is None:
|
||||
return
|
||||
|
||||
header = frame_data.header
|
||||
score_info = store.score.score_info
|
||||
score_info.accuracy = header.accuracy
|
||||
score_info.combo = header.combo
|
||||
score_info.max_combo = header.max_combo
|
||||
score_info.statistics = header.statistics
|
||||
store.score.replay_frames.extend(frame_data.frames)
|
||||
|
||||
await self.broadcast_group_call(self.group_id(user_id), "UserSentFrames", user_id, frame_data)
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
score = store.score
|
||||
|
||||
# Early return if no active session
|
||||
if (
|
||||
score is None
|
||||
or store.score_token is None
|
||||
or store.beatmap_status is None
|
||||
or store.state is None
|
||||
or store.score is None
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# Process score if conditions are met
|
||||
if (settings.enable_all_beatmap_leaderboard and store.beatmap_status.has_leaderboard()) and any(
|
||||
k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()
|
||||
):
|
||||
await self._process_score(store, client)
|
||||
|
||||
# End the play session and notify watchers
|
||||
await self._end_session(user_id, state, store)
|
||||
|
||||
finally:
|
||||
# CRITICAL FIX: Always clear state in finally block to ensure cleanup
|
||||
# This matches the official C# implementation pattern
|
||||
store.state = None
|
||||
store.beatmap_status = None
|
||||
store.checksum = None
|
||||
store.ruleset_id = None
|
||||
store.score_token = None
|
||||
store.score = None
|
||||
logger.info(f"[SpectatorHub] Cleared all session state for user {user_id}")
|
||||
|
||||
async def _process_score(self, store: StoreClientState, client: Client) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
assert store.state is not None
|
||||
assert store.score_token is not None
|
||||
assert store.checksum is not None
|
||||
assert store.ruleset_id is not None
|
||||
assert store.score is not None
|
||||
async with with_db() as session:
|
||||
async with session:
|
||||
start_time = time.time()
|
||||
score_record = None
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
sub_query = select(ScoreToken.score_id).where(
|
||||
ScoreToken.id == store.score_token,
|
||||
)
|
||||
result = await session.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap))
|
||||
.where(
|
||||
Score.id == sub_query.scalar_subquery(),
|
||||
Score.user_id == user_id,
|
||||
)
|
||||
)
|
||||
score_record = result.first()
|
||||
if score_record:
|
||||
break
|
||||
if not score_record:
|
||||
return
|
||||
if not score_record.passed:
|
||||
return
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserScoreProcessed",
|
||||
user_id,
|
||||
score_record.id,
|
||||
)
|
||||
# save replay
|
||||
score_record.has_replay = True
|
||||
await session.commit()
|
||||
await session.refresh(score_record)
|
||||
await save_replay(
|
||||
ruleset_id=store.ruleset_id,
|
||||
md5=store.checksum,
|
||||
username=store.score.score_info.user.name,
|
||||
score=score_record,
|
||||
statistics=store.score.score_info.statistics,
|
||||
maximum_statistics=store.score.score_info.maximum_statistics,
|
||||
frames=store.score.replay_frames,
|
||||
)
|
||||
|
||||
async def _end_session(self, user_id: int, state: SpectatorState, store: StoreClientState) -> None:
|
||||
async def _add_failtime():
|
||||
async with with_db() as session:
|
||||
failtime = await session.get(FailTime, state.beatmap_id)
|
||||
total_length = (
|
||||
await session.exec(select(Beatmap.total_length).where(Beatmap.id == state.beatmap_id))
|
||||
).one()
|
||||
index = clamp(round((exit_time / total_length) * 100), 0, 99)
|
||||
if failtime is not None:
|
||||
resp = FailTimeResp.from_db(failtime)
|
||||
else:
|
||||
resp = FailTimeResp()
|
||||
if state.state == SpectatedUserState.Failed:
|
||||
resp.fail[index] += 1
|
||||
elif state.state == SpectatedUserState.Quit:
|
||||
resp.exit[index] += 1
|
||||
|
||||
assert state.beatmap_id
|
||||
new_failtime = FailTime.from_resp(state.beatmap_id, resp)
|
||||
if failtime is not None:
|
||||
await session.merge(new_failtime)
|
||||
else:
|
||||
session.add(new_failtime)
|
||||
await session.commit()
|
||||
|
||||
async def _edit_playtime(token: int, ruleset_id: int, mods: list[APIMod]):
|
||||
redis = get_redis()
|
||||
key = f"score:existed_time:{token}"
|
||||
messages = await redis.xrange(key, min="-", max="+", count=1)
|
||||
if not messages:
|
||||
return
|
||||
before_time = int(messages[0][1]["time"])
|
||||
await redis.delete(key)
|
||||
async with with_db() as session:
|
||||
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
|
||||
statistics = (
|
||||
await session.exec(
|
||||
select(UserStatistics).where(
|
||||
UserStatistics.user_id == user_id,
|
||||
UserStatistics.mode == gamemode,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if statistics is None:
|
||||
return
|
||||
statistics.play_time -= before_time
|
||||
statistics.play_time += round(min(before_time, exit_time))
|
||||
|
||||
if state.state == SpectatedUserState.Playing:
|
||||
state.state = SpectatedUserState.Quit
|
||||
logger.debug(f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}")
|
||||
|
||||
# Calculate exit time safely
|
||||
exit_time = 0
|
||||
if store.score and store.score.replay_frames:
|
||||
exit_time = max(frame.time for frame in store.score.replay_frames) // 1000
|
||||
|
||||
# Background task for playtime editing - only if we have valid data
|
||||
if store.score_token and store.ruleset_id and store.score:
|
||||
task = asyncio.create_task(
|
||||
_edit_playtime(
|
||||
store.score_token,
|
||||
store.ruleset_id,
|
||||
store.score.score_info.mods,
|
||||
)
|
||||
)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
# Background task for failtime tracking - only for failed/quit states with valid data
|
||||
if (
|
||||
state.beatmap_id is not None
|
||||
and exit_time > 0
|
||||
and state.state in (SpectatedUserState.Failed, SpectatedUserState.Quit)
|
||||
):
|
||||
task = asyncio.create_task(_add_failtime())
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
logger.info(f"[SpectatorHub] {user_id} finished playing {state.beatmap_id} with {state.state}")
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserFinishedPlaying",
|
||||
user_id,
|
||||
state,
|
||||
)
|
||||
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
"""
|
||||
Enhanced StartWatchingUser based on official osu-server-spectator implementation.
|
||||
Properly handles state synchronization and watcher notifications.
|
||||
"""
|
||||
user_id = int(client.connection_id)
|
||||
|
||||
logger.info(f"[SpectatorHub] {user_id} started watching {target_id}")
|
||||
|
||||
try:
|
||||
# Get target user's current state if it exists
|
||||
target_store = self.state.get(target_id)
|
||||
if not target_store or not target_store.state:
|
||||
logger.info(f"[SpectatorHub] Rejecting watch request for {target_id}: user not playing")
|
||||
raise InvokeException("Target user is not currently playing")
|
||||
|
||||
if target_store.state.state != SpectatedUserState.Playing:
|
||||
logger.info(
|
||||
f"[SpectatorHub] Rejecting watch request for {target_id}: state is {target_store.state.state}"
|
||||
)
|
||||
raise InvokeException("Target user is not currently playing")
|
||||
|
||||
logger.debug(f"[SpectatorHub] {target_id} is currently playing, sending state")
|
||||
# Send current state to the watcher immediately
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
target_store.state,
|
||||
)
|
||||
except InvokeException:
|
||||
# Re-raise to inform caller without adding to group
|
||||
raise
|
||||
except Exception as e:
|
||||
# User isn't tracked or error occurred - this is not critical
|
||||
logger.debug(f"[SpectatorHub] Could not get state for {target_id}: {e}")
|
||||
raise InvokeException("Target user is not currently playing") from e
|
||||
|
||||
# Add watcher to our tracked users only after validation
|
||||
store = self.get_or_create_state(client)
|
||||
store.watched_user.add(target_id)
|
||||
|
||||
# Add to SignalR group for this target user
|
||||
self.add_to_group(client, self.group_id(target_id))
|
||||
|
||||
# Get watcher's username and notify the target user
|
||||
try:
|
||||
async with with_db() as session:
|
||||
username = (await session.exec(select(User.username).where(User.id == user_id))).first()
|
||||
if not username:
|
||||
logger.warning(f"[SpectatorHub] Could not find username for user {user_id}")
|
||||
return
|
||||
|
||||
# Notify target user that someone started watching
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
# Create watcher info array (matches official format)
|
||||
watcher_info = [[user_id, username]]
|
||||
await self.call_noblock(target_client, "UserStartedWatching", watcher_info)
|
||||
logger.debug(f"[SpectatorHub] Notified {target_id} that {username} started watching")
|
||||
except Exception as e:
|
||||
logger.error(f"[SpectatorHub] Error notifying target user {target_id}: {e}")
|
||||
|
||||
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
"""
|
||||
Enhanced EndWatchingUser based on official osu-server-spectator implementation.
|
||||
Properly cleans up watcher state and notifies target user.
|
||||
"""
|
||||
user_id = int(client.connection_id)
|
||||
|
||||
logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")
|
||||
|
||||
# Remove from SignalR group
|
||||
self.remove_from_group(client, self.group_id(target_id))
|
||||
|
||||
# Remove from our tracked watched users
|
||||
store = self.get_or_create_state(client)
|
||||
store.watched_user.discard(target_id)
|
||||
|
||||
# Notify target user that watcher stopped watching
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||
logger.debug(f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching")
|
||||
else:
|
||||
logger.debug(f"[SpectatorHub] Target user {target_id} not found for end watching notification")
|
||||
@@ -1,492 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import datetime
|
||||
from enum import Enum, IntEnum
|
||||
import inspect
|
||||
import json
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from app.models.signalr import SignalRMeta, SignalRUnionMessage
|
||||
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
|
||||
|
||||
import msgpack_lazer_api as m
|
||||
from pydantic import BaseModel
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
|
||||
class PacketType(IntEnum):
|
||||
INVOCATION = 1
|
||||
STREAM_ITEM = 2
|
||||
COMPLETION = 3
|
||||
STREAM_INVOCATION = 4
|
||||
CANCEL_INVOCATION = 5
|
||||
PING = 6
|
||||
CLOSE = 7
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Packet:
|
||||
type: PacketType
|
||||
header: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class InvocationPacket(Packet):
|
||||
type: PacketType = PacketType.INVOCATION
|
||||
invocation_id: str | None
|
||||
target: str
|
||||
arguments: list[Any] | None = None
|
||||
stream_ids: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CompletionPacket(Packet):
|
||||
type: PacketType = PacketType.COMPLETION
|
||||
invocation_id: str
|
||||
result: Any
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PingPacket(Packet):
|
||||
type: PacketType = PacketType.PING
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ClosePacket(Packet):
|
||||
type: PacketType = PacketType.CLOSE
|
||||
error: str | None = None
|
||||
allow_reconnect: bool = False
|
||||
|
||||
|
||||
PACKETS = {
|
||||
PacketType.INVOCATION: InvocationPacket,
|
||||
PacketType.COMPLETION: CompletionPacket,
|
||||
PacketType.PING: PingPacket,
|
||||
PacketType.CLOSE: ClosePacket,
|
||||
}
|
||||
|
||||
|
||||
class Protocol(TypingProtocol):
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]: ...
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes: ...
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type) -> Any: ...
|
||||
|
||||
|
||||
class MsgpackProtocol:
|
||||
@classmethod
|
||||
def serialize_msgpack(cls, v: Any) -> Any:
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_to_list(v)
|
||||
elif issubclass(typ, list):
|
||||
return [cls.serialize_msgpack(item) for item in v]
|
||||
elif issubclass(typ, datetime.datetime):
|
||||
return [v, 0]
|
||||
elif issubclass(typ, datetime.timedelta):
|
||||
return int(v.total_seconds() * 10_000_000)
|
||||
elif isinstance(v, dict):
|
||||
return {cls.serialize_msgpack(k): cls.serialize_msgpack(value) for k, value in v.items()}
|
||||
elif issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v) if v in list_ else v.value
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
|
||||
values = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
values.append(cls.serialize_msgpack(v=getattr(value, field)))
|
||||
if issubclass(value.__class__, SignalRUnionMessage):
|
||||
return [value.__class__.union_type, values]
|
||||
else:
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def process_object(v: Any, typ: type[BaseModel]) -> Any:
|
||||
if isinstance(v, list):
|
||||
d = {}
|
||||
i = 0
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
|
||||
if metadata and metadata.member_ignore:
|
||||
continue
|
||||
anno = info.annotation
|
||||
if anno is None:
|
||||
d[camel_to_snake(field)] = v[i]
|
||||
else:
|
||||
d[field] = MsgpackProtocol.validate_object(v[i], anno)
|
||||
i += 1
|
||||
return d
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
@staticmethod
|
||||
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
||||
result = 0
|
||||
shift = 0
|
||||
pos = offset
|
||||
|
||||
while pos < len(data):
|
||||
byte = data[pos]
|
||||
result |= (byte & 0x7F) << shift
|
||||
pos += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
|
||||
return result, pos
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
length, offset = MsgpackProtocol._decode_varint(input)
|
||||
message_data = input[offset : offset + length]
|
||||
unpacked = m.decode(message_data)
|
||||
packet_type = PacketType(unpacked[0])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
target=unpacked[3],
|
||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
result_kind = unpacked[3]
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
error=unpacked[4] if result_kind == 1 else None,
|
||||
result=unpacked[5] if result_kind == 3 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
case PacketType.CLOSE:
|
||||
return [
|
||||
ClosePacket(
|
||||
error=unpacked[1],
|
||||
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
|
||||
)
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type) -> Any:
|
||||
if issubclass(typ, BaseModel):
|
||||
return typ.model_validate(obj=cls.process_object(v, typ))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return v[0]
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
return datetime.timedelta(seconds=int(v / 10_000_000))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||
elif get_origin(typ) is dict:
|
||||
return {
|
||||
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(v, get_args(typ)[1]) for k, v in v.items()
|
||||
}
|
||||
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||
args = get_args(typ)
|
||||
if len(args) == 2 and NoneType in args:
|
||||
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||
if len(non_none_args) == 1:
|
||||
if v is None:
|
||||
return None
|
||||
return cls.validate_object(v, non_none_args[0])
|
||||
|
||||
# suppose use `MessagePack-CSharp Union | None`
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args):
|
||||
raise ValueError(f"Cannot validate {v} to {typ}, only SignalRUnionMessage subclasses are supported")
|
||||
union_type = v[0]
|
||||
for arg in args:
|
||||
assert issubclass(arg, SignalRUnionMessage)
|
||||
if arg.union_type == union_type:
|
||||
return cls.validate_object(v[1], arg)
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload = [packet.type.value, packet.header or {}]
|
||||
if isinstance(packet, InvocationPacket):
|
||||
payload.extend(
|
||||
[
|
||||
packet.invocation_id,
|
||||
packet.target,
|
||||
]
|
||||
)
|
||||
if packet.arguments is not None:
|
||||
payload.append([MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments])
|
||||
if packet.stream_ids is not None:
|
||||
payload.append(packet.stream_ids)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
result_kind = 2
|
||||
if packet.error:
|
||||
result_kind = 1
|
||||
elif packet.result is not None:
|
||||
result_kind = 3
|
||||
payload.extend(
|
||||
[
|
||||
packet.invocation_id,
|
||||
result_kind,
|
||||
packet.error or MsgpackProtocol.serialize_msgpack(packet.result) or None,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, ClosePacket):
|
||||
payload.extend(
|
||||
[
|
||||
packet.error or "",
|
||||
packet.allow_reconnect,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, PingPacket):
|
||||
payload.pop(-1)
|
||||
data = m.encode(payload)
|
||||
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||
|
||||
|
||||
class JSONProtocol:
|
||||
@classmethod
|
||||
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
|
||||
typ = v.__class__
|
||||
if issubclass(typ, BaseModel):
|
||||
return cls.serialize_model(v, in_union)
|
||||
elif isinstance(v, dict):
|
||||
return {cls.serialize_to_json(k, True): cls.serialize_to_json(value) for k, value in v.items()}
|
||||
elif isinstance(v, list):
|
||||
return [cls.serialize_to_json(item) for item in v]
|
||||
elif isinstance(v, datetime.datetime):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
total_seconds = int(v.total_seconds())
|
||||
hours, remainder = divmod(total_seconds, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
||||
elif isinstance(v, Enum) and dict_key:
|
||||
return v.value
|
||||
elif isinstance(v, Enum):
|
||||
list_ = list(typ)
|
||||
return list_.index(v)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
|
||||
d = {}
|
||||
is_union = issubclass(v.__class__, SignalRUnionMessage)
|
||||
for field, info in v.__class__.model_fields.items():
|
||||
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
name = (
|
||||
snake_to_camel(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
if not is_union
|
||||
else snake_to_pascal(
|
||||
field,
|
||||
metadata.use_abbr if metadata else True,
|
||||
)
|
||||
)
|
||||
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
|
||||
if is_union and not in_union:
|
||||
return {
|
||||
"$dtype": v.__class__.__name__,
|
||||
"$value": d,
|
||||
}
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def process_object(v: Any, typ: type[BaseModel], from_union: bool = False) -> dict[str, Any]:
|
||||
d = {}
|
||||
for field, info in typ.model_fields.items():
|
||||
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
|
||||
if metadata and metadata.json_ignore:
|
||||
continue
|
||||
name = (
|
||||
snake_to_camel(field, metadata.use_abbr if metadata else True)
|
||||
if not from_union
|
||||
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
|
||||
)
|
||||
value = v.get(name)
|
||||
anno = typ.model_fields[field].annotation
|
||||
if anno is None:
|
||||
d[field] = value
|
||||
continue
|
||||
d[field] = JSONProtocol.validate_object(value, anno)
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||
packets = []
|
||||
if len(packets_raw) > 1:
|
||||
for packet_raw in packets_raw:
|
||||
packets.extend(JSONProtocol.decode(packet_raw))
|
||||
return packets
|
||||
else:
|
||||
data = json.loads(packets_raw[0])
|
||||
packet_type = PacketType(data["type"])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data.get("invocationId"),
|
||||
target=data["target"],
|
||||
arguments=data.get("arguments"),
|
||||
stream_ids=data.get("streamIds"),
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data["invocationId"],
|
||||
error=data.get("error"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
case PacketType.CLOSE:
|
||||
return [
|
||||
ClosePacket(
|
||||
error=data.get("error"),
|
||||
allow_reconnect=data.get("allowReconnect", False),
|
||||
)
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@classmethod
|
||||
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
|
||||
if issubclass(typ, BaseModel):
|
||||
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
|
||||
return datetime.datetime.fromisoformat(v)
|
||||
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
|
||||
# d.hh:mm:ss
|
||||
parts = v.split(":")
|
||||
if len(parts) == 3:
|
||||
return datetime.timedelta(hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]))
|
||||
elif len(parts) == 2:
|
||||
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
|
||||
elif len(parts) == 1:
|
||||
return datetime.timedelta(seconds=int(parts[0]))
|
||||
elif get_origin(typ) is list:
|
||||
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
|
||||
elif inspect.isclass(typ) and issubclass(typ, Enum):
|
||||
list_ = list(typ)
|
||||
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
|
||||
elif get_origin(typ) is dict:
|
||||
return {
|
||||
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(v, get_args(typ)[1]) for k, v in v.items()
|
||||
}
|
||||
elif (origin := get_origin(typ)) is Union or origin is UnionType:
|
||||
args = get_args(typ)
|
||||
if len(args) == 2 and NoneType in args:
|
||||
non_none_args = [arg for arg in args if arg is not NoneType]
|
||||
if len(non_none_args) == 1:
|
||||
if v is None:
|
||||
return None
|
||||
return cls.validate_object(v, non_none_args[0])
|
||||
|
||||
# suppose use `MessagePack-CSharp Union | None`
|
||||
# except `X (Other Type) | None`
|
||||
if NoneType in args and v is None:
|
||||
return None
|
||||
if not all(issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args):
|
||||
raise ValueError(f"Cannot validate {v} to {typ}, only SignalRUnionMessage subclasses are supported")
|
||||
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
|
||||
union_type = v["$dtype"]
|
||||
for arg in args:
|
||||
assert issubclass(arg, SignalRUnionMessage)
|
||||
if arg.__name__ == union_type:
|
||||
return cls.validate_object(v["$value"], arg, True)
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload: dict[str, Any] = {
|
||||
"type": packet.type.value,
|
||||
}
|
||||
if packet.header:
|
||||
payload["header"] = packet.header
|
||||
if isinstance(packet, InvocationPacket):
|
||||
payload.update(
|
||||
{
|
||||
"target": packet.target,
|
||||
}
|
||||
)
|
||||
if packet.invocation_id is not None:
|
||||
payload["invocationId"] = packet.invocation_id
|
||||
if packet.arguments is not None:
|
||||
payload["arguments"] = [JSONProtocol.serialize_to_json(arg) for arg in packet.arguments]
|
||||
if packet.stream_ids is not None:
|
||||
payload["streamIds"] = packet.stream_ids
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
payload.update(
|
||||
{
|
||||
"invocationId": packet.invocation_id,
|
||||
}
|
||||
)
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
if packet.result is not None:
|
||||
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
elif isinstance(packet, ClosePacket):
|
||||
payload.update(
|
||||
{
|
||||
"allowReconnect": packet.allow_reconnect,
|
||||
}
|
||||
)
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
return json.dumps(payload).encode("utf-8") + SEP
|
||||
|
||||
|
||||
PROTOCOLS: dict[str, Protocol] = {
|
||||
"json": JSONProtocol,
|
||||
"messagepack": MsgpackProtocol,
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies.database import DBFactory, get_db_factory
|
||||
from app.dependencies.user import get_current_user, get_current_user_and_token
|
||||
from app.log import logger
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
|
||||
from .hub import Hubs
|
||||
from .packet import PROTOCOLS, SEP
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
||||
from fastapi.security import SecurityScopes
|
||||
|
||||
router = APIRouter(prefix="/signalr", include_in_schema=False)
|
||||
logger.warning(
|
||||
"The Python version of SignalR server is deprecated. "
|
||||
"Maybe it will be removed or be fixed to continuously use in the future"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
|
||||
async def negotiate(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
negotiate_version: int = Query(1, alias="negotiateVersion"),
|
||||
user: DBUser = Depends(get_current_user),
|
||||
):
|
||||
connectionId = str(user.id)
|
||||
connectionToken = f"{connectionId}:{uuid.uuid4()}"
|
||||
Hubs[hub].add_waited_client(
|
||||
connection_token=connectionToken,
|
||||
timestamp=int(time.time()),
|
||||
)
|
||||
return NegotiateResponse(
|
||||
connectionId=connectionId,
|
||||
connectionToken=connectionToken,
|
||||
negotiateVersion=negotiate_version,
|
||||
availableTransports=[Transport(transport="WebSockets")],
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/{hub}")
|
||||
async def connect(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
hub_ = Hubs[hub]
|
||||
if id not in hub_:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
try:
|
||||
async for session in factory():
|
||||
if (
|
||||
user_and_token := await get_current_user_and_token(
|
||||
session, SecurityScopes(scopes=["*"]), token_pw=token
|
||||
)
|
||||
) is None or str(user_and_token[0].id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
except HTTPException:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await websocket.accept()
|
||||
|
||||
# handshake
|
||||
handshake = await websocket.receive()
|
||||
message = handshake.get("bytes") or handshake.get("text")
|
||||
if not message:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
handshake_payload = json.loads(message[:-1])
|
||||
error = ""
|
||||
protocol = handshake_payload.get("protocol", "json")
|
||||
|
||||
client = None
|
||||
try:
|
||||
client = await hub_.add_client(
|
||||
connection_id=user_id,
|
||||
connection_token=id,
|
||||
connection=websocket,
|
||||
protocol=PROTOCOLS[protocol],
|
||||
)
|
||||
except KeyError:
|
||||
error = f"Protocol '{protocol}' is not supported."
|
||||
except TimeoutError:
|
||||
error = f"Connection {id} has waited too long."
|
||||
except ValueError as e:
|
||||
error = str(e)
|
||||
payload = {"error": error} if error else {}
|
||||
# finish handshake
|
||||
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
||||
if error or not client:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
connected_clients = hub_.get_before_clients(user_id, id)
|
||||
for connected_client in connected_clients:
|
||||
await hub_.kick_client(connected_client)
|
||||
|
||||
await hub_.clean_state(client, False)
|
||||
task = asyncio.create_task(hub_.on_connect(client))
|
||||
hub_.tasks.add(task)
|
||||
task.add_done_callback(hub_.tasks.discard)
|
||||
await hub_._listen_client(client)
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
...
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ResultStore:
|
||||
def __init__(self) -> None:
|
||||
self._seq: int = 1
|
||||
self._futures: dict[str, asyncio.Future] = {}
|
||||
|
||||
@property
|
||||
def current_invocation_id(self) -> int:
|
||||
return self._seq
|
||||
|
||||
def get_invocation_id(self) -> str:
|
||||
s = self._seq
|
||||
self._seq = (self._seq + 1) % sys.maxsize
|
||||
return str(s)
|
||||
|
||||
def add_result(self, invocation_id: str, result: Any, error: str | None = None) -> None:
|
||||
if isinstance(invocation_id, str) and invocation_id.isdecimal():
|
||||
if future := self._futures.get(invocation_id):
|
||||
future.set_result((result, error))
|
||||
|
||||
async def fetch(
|
||||
self,
|
||||
invocation_id: str,
|
||||
timeout: float | None, # noqa: ASYNC109
|
||||
) -> tuple[Any, str | None]:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._futures[invocation_id] = future
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout)
|
||||
finally:
|
||||
del self._futures[invocation_id]
|
||||
@@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Any, ForwardRef, cast
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L61-L75
|
||||
if sys.version_info < (3, 12, 4):
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
|
||||
|
||||
|
||||
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
try:
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
except Exception:
|
||||
return inspect.Parameter.empty
|
||||
return annotation
|
||||
|
||||
|
||||
def get_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_annotation(param, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return inspect.Signature(typed_params)
|
||||
@@ -14,7 +14,6 @@ from app.database.user import User
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.dependencies.scheduler import get_scheduler
|
||||
from app.log import logger
|
||||
from app.models.metadata_hub import DailyChallengeInfo
|
||||
from app.models.mods import APIMod, get_available_mods
|
||||
from app.models.room import RoomCategory
|
||||
from app.service.room import create_playlist_room
|
||||
@@ -54,8 +53,6 @@ async def create_daily_challenge_room(
|
||||
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="daily_challenge")
|
||||
async def daily_challenge_job():
|
||||
from app.signalr.hub import MetadataHubs
|
||||
|
||||
now = utcnow()
|
||||
redis = get_redis()
|
||||
key = f"daily_challenge:{now.date()}"
|
||||
@@ -108,7 +105,6 @@ async def daily_challenge_job():
|
||||
allowed_mods=allowed_mods_list,
|
||||
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
|
||||
)
|
||||
await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
|
||||
logger.success(f"Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
|
||||
return
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
|
||||
Reference in New Issue
Block a user