refactor(signalr): remove SignalR server & msgpack_lazer_api

Maybe we can make `msgpack_lazer_api` independent?
This commit is contained in:
MingxuanGame
2025-10-03 13:20:12 +00:00
parent d23f32f08d
commit 0d9019c6cc
39 changed files with 312 additions and 6252 deletions

View File

@@ -266,18 +266,6 @@ STORAGE_SETTINGS='{
else:
return "/"
# SignalR 设置
signalr_negotiate_timeout: Annotated[
int,
Field(default=30, description="SignalR 协商超时时间(秒)"),
"SignalR 服务器设置",
]
signalr_ping_interval: Annotated[
int,
Field(default=15, description="SignalR ping 间隔(秒)"),
"SignalR 服务器设置",
]
# Fetcher 设置
fetcher_client_id: Annotated[
str,

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.playlist import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
@@ -21,8 +22,6 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.models.multiplayer_hub import PlaylistItem
from .room import Room
@@ -73,7 +72,7 @@ class Playlist(PlaylistBase, table=True):
return result.one()
@classmethod
async def from_hub(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession) -> "Playlist":
async def from_model(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
@@ -90,7 +89,7 @@ class Playlist(PlaylistBase, table=True):
)
@classmethod
async def update(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
db_playlist = db_playlist.first()
if db_playlist is None:
@@ -107,8 +106,8 @@ class Playlist(PlaylistBase, table=True):
await session.commit()
@classmethod
async def add_to_db(cls, playlist: "PlaylistItem", room_id: int, session: AsyncSession):
db_playlist = await cls.from_hub(playlist, room_id, session)
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await cls.from_model(playlist, room_id, session)
session.add(db_playlist)
await session.commit()
await session.refresh(db_playlist)

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.database.item_attempts_count import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser
@@ -32,9 +31,6 @@ from sqlmodel import (
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.models.multiplayer_hub import ServerMultiplayerRoom
class RoomBase(SQLModel, UTCBaseModel):
name: str = Field(index=True)
@@ -163,25 +159,6 @@ class RoomResp(RoomBase):
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
return resp
@classmethod
async def from_hub(cls, server_room: "ServerMultiplayerRoom") -> "RoomResp":
room = server_room.room
resp = cls(
id=room.room_id,
name=room.settings.name,
type=room.settings.match_type,
queue_mode=room.settings.queue_mode,
auto_skip=room.settings.auto_skip,
auto_start_duration=int(room.settings.auto_start_duration.total_seconds()),
status=server_room.status,
category=server_room.category,
# duration = room.settings.duration,
starts_at=server_room.start_at,
participant_count=len(room.users),
channel_id=server_room.room.channel_id or 0,
)
return resp
class APIUploadedRoom(RoomBase):
def to_room(self) -> Room:

View File

@@ -1,10 +0,0 @@
from __future__ import annotations
class SignalRException(Exception):
pass
class InvokeException(SignalRException):
def __init__(self, message: str) -> None:
self.message = message

View File

@@ -1,73 +0,0 @@
"""
会话验证接口
基于osu-web的SessionVerificationInterface实现
用于标准化会话验证行为
"""
from __future__ import annotations
from abc import ABC, abstractmethod
class SessionVerificationInterface(ABC):
"""会话验证接口
定义了会话验证所需的基本操作参考osu-web的实现
"""
@classmethod
@abstractmethod
async def find_for_verification(cls, session_id: str) -> SessionVerificationInterface | None:
"""根据会话ID查找会话用于验证
Args:
session_id: 会话ID
Returns:
会话实例或None
"""
pass
@abstractmethod
def get_key(self) -> str:
"""获取会话密钥/ID"""
pass
@abstractmethod
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
pass
@abstractmethod
def get_verification_method(self) -> str | None:
"""获取当前验证方法
Returns:
验证方法 ('totp', 'mail') 或 None
"""
pass
@abstractmethod
def is_verified(self) -> bool:
"""检查会话是否已验证"""
pass
@abstractmethod
async def mark_verified(self) -> None:
"""标记会话为已验证"""
pass
@abstractmethod
async def set_verification_method(self, method: str) -> None:
"""设置验证方法
Args:
method: 验证方法 ('totp', 'mail')
"""
pass
@abstractmethod
def user_id(self) -> int | None:
"""获取关联的用户ID"""
pass

View File

@@ -1,157 +0,0 @@
from __future__ import annotations
from enum import IntEnum
from typing import ClassVar, Literal
from app.models.signalr import SignalRUnionMessage, UserState
from pydantic import BaseModel, Field
TOTAL_SCORE_DISTRIBUTION_BINS = 13
class _UserActivity(SignalRUnionMessage): ...
class ChoosingBeatmap(_UserActivity):
union_type: ClassVar[Literal[11]] = 11
class _InGame(_UserActivity):
beatmap_id: int
beatmap_display_title: str
ruleset_id: int
ruleset_playing_verb: str
class InSoloGame(_InGame):
union_type: ClassVar[Literal[12]] = 12
class InMultiplayerGame(_InGame):
union_type: ClassVar[Literal[23]] = 23
class SpectatingMultiplayerGame(_InGame):
union_type: ClassVar[Literal[24]] = 24
class InPlaylistGame(_InGame):
union_type: ClassVar[Literal[31]] = 31
class PlayingDailyChallenge(_InGame):
union_type: ClassVar[Literal[52]] = 52
class EditingBeatmap(_UserActivity):
union_type: ClassVar[Literal[41]] = 41
beatmap_id: int
beatmap_display_title: str
class TestingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[43]] = 43
class ModdingBeatmap(EditingBeatmap):
union_type: ClassVar[Literal[42]] = 42
class WatchingReplay(_UserActivity):
union_type: ClassVar[Literal[13]] = 13
score_id: int
player_name: str
beatmap_id: int
beatmap_display_title: str
class SpectatingUser(WatchingReplay):
union_type: ClassVar[Literal[14]] = 14
class SearchingForLobby(_UserActivity):
union_type: ClassVar[Literal[21]] = 21
class InLobby(_UserActivity):
union_type: ClassVar[Literal[22]] = 22
room_id: int
room_name: str
class InDailyChallengeLobby(_UserActivity):
union_type: ClassVar[Literal[51]] = 51
UserActivity = (
ChoosingBeatmap
| InSoloGame
| WatchingReplay
| SpectatingUser
| SearchingForLobby
| InLobby
| InMultiplayerGame
| SpectatingMultiplayerGame
| InPlaylistGame
| EditingBeatmap
| ModdingBeatmap
| TestingBeatmap
| InDailyChallengeLobby
| PlayingDailyChallenge
)
class UserPresence(BaseModel):
activity: UserActivity | None = None
status: OnlineStatus | None = None
@property
def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE
@property
def for_push(self) -> "UserPresence | None":
return UserPresence(
activity=self.activity,
status=self.status,
)
class MetadataClientState(UserPresence, UserState): ...
class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身
DO_NOT_DISTURB = 1
ONLINE = 2
class DailyChallengeInfo(BaseModel):
room_id: int
class MultiplayerPlaylistItemStats(BaseModel):
playlist_item_id: int = 0
total_score_distribution: list[int] = Field(
default_factory=list,
min_length=TOTAL_SCORE_DISTRIBUTION_BINS,
max_length=TOTAL_SCORE_DISTRIBUTION_BINS,
)
cumulative_score: int = 0
last_processed_score_id: int = 0
class MultiplayerRoomStats(BaseModel):
room_id: int
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
class MultiplayerRoomScoreSetEvent(BaseModel):
room_id: int
playlist_item_id: int
score_id: int
user_id: int
total_score: int
new_rank: int | None = None

View File

@@ -1,840 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import IntEnum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
TypedDict,
cast,
override,
)
from app.database.beatmap import Beatmap
from app.dependencies.database import with_db
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.utils import utcnow
from .mods import API_MODS, APIMod
from .room import (
DownloadState,
MatchType,
MultiplayerRoomState,
MultiplayerUserState,
QueueMode,
RoomCategory,
RoomStatus,
)
from .signalr import (
SignalRMeta,
SignalRUnionMessage,
UserState,
)
from pydantic import BaseModel, Field
from sqlalchemy import update
from sqlmodel import col
if TYPE_CHECKING:
from app.database.room import Room
from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50
PER_USER_LIMIT = 3
class MultiplayerClientState(UserState):
room_id: int = 0
class MultiplayerRoomSettings(BaseModel):
name: str = "Unnamed Room"
playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
password: str = ""
match_type: MatchType = MatchType.HEAD_TO_HEAD
queue_mode: QueueMode = QueueMode.HOST_ONLY
auto_start_duration: timedelta = timedelta(seconds=0)
auto_skip: bool = False
@property
def auto_start_enabled(self) -> bool:
return self.auto_start_duration != timedelta(seconds=0)
class BeatmapAvailability(BaseModel):
state: DownloadState = DownloadState.UNKNOWN
download_progress: float | None = None
class _MatchUserState(SignalRUnionMessage): ...
class TeamVersusUserState(_MatchUserState):
team_id: int
union_type: ClassVar[Literal[0]] = 0
MatchUserState = TeamVersusUserState
class _MatchRoomState(SignalRUnionMessage): ...
class MultiplayerTeam(BaseModel):
id: int
name: str
class TeamVersusRoomState(_MatchRoomState):
teams: list[MultiplayerTeam] = Field(
default_factory=lambda: [
MultiplayerTeam(id=0, name="Team Red"),
MultiplayerTeam(id=1, name="Team Blue"),
]
)
union_type: ClassVar[Literal[0]] = 0
MatchRoomState = TeamVersusRoomState
class PlaylistItem(BaseModel):
id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
owner_id: int
beatmap_id: int
beatmap_checksum: str
ruleset_id: int
required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool
playlist_order: int
played_at: datetime | None = None
star_rating: float
freestyle: bool
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False:
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
for i, mod1 in enumerate(mods):
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
if mod1_settings:
incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible:
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
for req_mod in self.required_mods:
req_acronym = req_mod["acronym"]
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
if req_settings:
incompatible = set(req_settings.get("IncompatibleMods", []))
conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed)
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
# Validate required mods
for mod in self.required_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
# Validate allowed mods
for mod in self.allowed_mods:
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
# Check internal compatibility of required mods
self._check_mod_compatibility(self.required_mods, ruleset_key)
# Check compatibility between required and allowed mods
self._check_required_allowed_compatibility(ruleset_key)
def validate_user_mods(
self,
user: "MultiplayerRoomUser",
proposed_mods: list[APIMod],
) -> tuple[bool, list[APIMod]]:
"""
Validates user mods against playlist item rules and returns valid mods.
Returns (is_valid, valid_mods).
"""
from typing import Literal, cast
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
valid_mods = []
all_proposed_valid = True
# Check if mods are valid for the ruleset
for mod in proposed_mods:
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
all_proposed_valid = False
continue
valid_mods.append(mod)
# Check mod compatibility within user mods
incompatible_mods = set()
final_valid_mods = []
for mod in valid_mods:
if mod["acronym"] in incompatible_mods:
all_proposed_valid = False
continue
setting_mods = API_MODS[ruleset_key].get(mod["acronym"])
if setting_mods:
incompatible_mods.update(setting_mods["IncompatibleMods"])
final_valid_mods.append(mod)
# If not freestyle, check against allowed mods
if not self.freestyle:
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
filtered_valid_mods = []
for mod in final_valid_mods:
if mod["acronym"] not in allowed_acronyms:
all_proposed_valid = False
else:
filtered_valid_mods.append(mod)
final_valid_mods = filtered_valid_mods
# Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
# Check for incompatibility between required and user mods
filtered_valid_mods = []
for mod in final_valid_mods:
mod_acronym = mod["acronym"]
is_compatible = True
for other_acronym in all_mod_acronyms:
if other_acronym == mod_acronym:
continue
setting_mods = API_MODS[ruleset_key].get(mod_acronym)
if setting_mods and other_acronym in setting_mods["IncompatibleMods"]:
is_compatible = False
all_proposed_valid = False
break
if is_compatible:
filtered_valid_mods.append(mod)
return all_proposed_valid, filtered_valid_mods
def clone(self) -> "PlaylistItem":
copy = self.model_copy()
copy.required_mods = list(self.required_mods)
copy.allowed_mods = list(self.allowed_mods)
copy.expired = False
copy.played_at = None
return copy
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
time_remaining: timedelta
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
class MatchStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[0]] = 0
class ForceGameplayStartCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[1]] = 1
class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle
beatmap_id: int | None = None # freestyle
class MultiplayerRoom(BaseModel):
room_id: int
state: MultiplayerRoomState
settings: MultiplayerRoomSettings
users: list[MultiplayerRoomUser] = Field(default_factory=list)
host: MultiplayerRoomUser | None = None
match_state: MatchRoomState | None = None
playlist: list[PlaylistItem] = Field(default_factory=list)
active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
channel_id: int
@classmethod
def from_db(cls, room: "Room") -> "MultiplayerRoom":
"""
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
"""
# 用户列表
users = [MultiplayerRoomUser(user_id=room.host_id)]
host_user = MultiplayerRoomUser(user_id=room.host_id)
# playlist 转换
playlist = []
if room.playlist:
for item in room.playlist:
playlist.append(
PlaylistItem(
id=item.id,
owner_id=item.owner_id,
beatmap_id=item.beatmap_id,
beatmap_checksum=item.beatmap.checksum if item.beatmap else "",
ruleset_id=item.ruleset_id,
required_mods=item.required_mods,
allowed_mods=item.allowed_mods,
expired=item.expired,
playlist_order=item.playlist_order,
played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
freestyle=item.freestyle,
)
)
return cls(
room_id=room.id,
state=getattr(room, "state", MultiplayerRoomState.OPEN),
settings=MultiplayerRoomSettings(
name=room.name,
playlist_item_id=playlist[0].id if playlist else 0,
password=getattr(room, "password", ""),
match_type=room.type,
queue_mode=room.queue_mode,
auto_start_duration=timedelta(seconds=room.auto_start_duration),
auto_skip=room.auto_skip,
),
users=users,
host=host_user,
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=room.channel_id or 0,
)
class MultiplayerQueue:
def __init__(self, room: "ServerMultiplayerRoom"):
self.server_room = room
self.current_index = 0
@property
def hub(self) -> "MultiplayerHub":
return self.server_room.hub
@property
def upcoming_items(self):
return sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda i: i.playlist_order,
)
@property
def room(self):
return self.server_room.room
async def update_order(self):
from app.database import Playlist
match self.room.settings.queue_mode:
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
ordered_active_items = []
is_first_set = True
first_set_order_by_user_id = {}
active_items = [item for item in self.room.playlist if not item.expired]
active_items.sort(key=lambda x: x.id)
user_item_groups = {}
for item in active_items:
if item.owner_id not in user_item_groups:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max((len(items) for items in user_item_groups.values()), default=0)
for i in range(max_items):
current_set = []
for user_id, items in user_item_groups.items():
if i < len(items):
current_set.append(items[i])
if is_first_set:
current_set.sort(key=lambda item: (item.playlist_order, item.id))
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
ordered_active_items.extend(current_set)
is_first_set = False
case _:
ordered_active_items = sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda x: x.id,
)
async with with_db() as session:
for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
async def update_current_item(self):
upcoming_items = self.upcoming_items
if upcoming_items:
# 优先选择未过期的项目
next_item = upcoming_items[0]
else:
# 如果所有项目都过期了选择最近添加的项目played_at 为 None 或最新的)
# 优先选择 expired=False 的项目,然后是 played_at 最晚的
next_item = max(
self.room.playlist,
key=lambda i: (not i.expired, i.played_at or datetime.min),
)
self.current_index = self.room.playlist.index(next_item)
last_id = self.room.settings.playlist_item_id
self.room.settings.playlist_item_id = next_item.id
if last_id != next_item.id:
await self.hub.setting_changed(self.server_room, True)
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
is_host = self.room.host and self.room.host.user_id == user.user_id
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = beatmap.difficulty_rating
await Playlist.add_to_db(item, self.room.room_id, session)
self.room.playlist.append(item)
await self.hub.playlist_added(self.server_room, item)
await self.update_order()
await self.update_current_item()
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
if existing_item is None:
raise InvokeException("Attempted to change an item that doesn't exist")
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException("Attempted to change an item which is not owned by the user")
if existing_item.expired:
raise InvokeException("Attempted to change an item which has already been played")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
item.star_rating = float(beatmap.difficulty_rating)
item.playlist_order = existing_item.playlist_order
await Playlist.update(item, self.room.room_id, session)
# Update item in playlist
for idx, playlist_item in enumerate(self.room.playlist):
if playlist_item.id == item.id:
self.room.playlist[idx] = item
break
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
from app.database import Playlist
item = next(
(i for i in self.room.playlist if i.id == playlist_item_id),
None,
)
if item is None:
raise InvokeException("Item does not exist in the room")
# Check if it's the only item and current item
if item == self.current_item:
upcoming_items = [i for i in self.room.playlist if not i.expired]
if len(upcoming_items) == 1:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException("Attempted to remove an item which is not owned by the user")
if item.expired:
raise InvokeException("Attempted to remove an item which has already been played")
async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
found_item = next((i for i in self.room.playlist if i.id == item.id), None)
if found_item:
self.room.playlist.remove(found_item)
self.current_index = self.room.playlist.index(self.upcoming_items[0])
await self.update_order()
await self.update_current_item()
await self.hub.playlist_removed(self.server_room, item.id)
async def finish_current_item(self):
from app.database import Playlist
async with with_db() as session:
played_at = utcnow()
await session.execute(
update(Playlist)
.where(
col(Playlist.id) == self.current_item.id,
col(Playlist.room_id) == self.room.room_id,
)
.values(expired=True, played_at=played_at)
)
self.room.playlist[self.current_index].expired = True
self.room.playlist[self.current_index].played_at = played_at
await self.hub.playlist_changed(self.server_room, self.current_item, True)
await self.update_order()
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_current_item()
async def update_queue_mode(self):
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_order()
await self.update_current_item()
@property
def current_item(self):
return self.room.playlist[self.current_index]
@dataclass
class CountdownInfo:
countdown: MultiplayerCountdown
duration: timedelta
task: asyncio.Task | None = None
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
)
class _MatchRequest(SignalRUnionMessage): ...
class ChangeTeamRequest(_MatchRequest):
union_type: ClassVar[Literal[0]] = 0
team_id: int
class StartMatchCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[1]] = 1
duration: timedelta
class StopCountdownRequest(_MatchRequest):
union_type: ClassVar[Literal[2]] = 2
id: int
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
class MatchTypeHandler(ABC):
def __init__(self, room: "ServerMultiplayerRoom"):
self.room = room
self.hub = room.hub
@abstractmethod
async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ...
@abstractmethod
def get_details(self) -> MatchStartedEventDetail: ...
class HeadToHeadHandler(MatchTypeHandler):
@override
async def handle_join(self, user: MultiplayerRoomUser):
if user.match_state is not None:
user.match_state = None
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
detail = MatchStartedEventDetail(room_type="head_to_head", team=None)
return detail
class TeamVersusHandler(MatchTypeHandler):
@override
def __init__(self, room: "ServerMultiplayerRoom"):
super().__init__(room)
self.state = TeamVersusRoomState()
room.room.match_state = self.state
task = asyncio.create_task(self.hub.change_room_match_state(self.room))
self.hub.tasks.add(task)
task.add_done_callback(self.hub.tasks.discard)
def _get_best_available_team(self) -> int:
for team in self.state.teams:
if all(
(
user.match_state is None
or not isinstance(user.match_state, TeamVersusUserState)
or user.match_state.team_id != team.id
)
for user in self.room.room.users
):
return team.id
from collections import defaultdict
team_counts = defaultdict(int)
for user in self.room.room.users:
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
team_counts[user.match_state.team_id] += 1
if team_counts:
min_count = min(team_counts.values())
for team_id, count in team_counts.items():
if count == min_count:
return team_id
return self.state.teams[0].id if self.state.teams else 0
@override
async def handle_join(self, user: MultiplayerRoomUser):
best_team_id = self._get_best_available_team()
user.match_state = TeamVersusUserState(team_id=best_team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
if not isinstance(request, ChangeTeamRequest):
return
if request.team_id not in [team.id for team in self.state.teams]:
raise InvokeException("Invalid team ID")
user.match_state = TeamVersusUserState(team_id=request.team_id)
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@override
def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users:
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail
MATCH_TYPE_HANDLERS = {
MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
MatchType.TEAM_VERSUS: TeamVersusHandler,
}
@dataclass
class ServerMultiplayerRoom:
room: MultiplayerRoom
category: RoomCategory
status: RoomStatus
start_at: datetime
hub: "MultiplayerHub"
match_type_handler: MatchTypeHandler
queue: MultiplayerQueue
_next_countdown_id: int
_countdown_id_lock: asyncio.Lock
_tracked_countdown: dict[int, CountdownInfo]
def __init__(
self,
room: MultiplayerRoom,
category: RoomCategory,
start_at: datetime,
hub: "MultiplayerHub",
):
self.room = room
self.category = category
self.status = RoomStatus.IDLE
self.start_at = start_at
self.hub = hub
self.queue = MultiplayerQueue(self)
self._next_countdown_id = 0
self._countdown_id_lock = asyncio.Lock()
self._tracked_countdown = {}
async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
for i in self.room.users:
await self.match_type_handler.handle_join(i)
async def get_next_countdown_id(self) -> int:
async with self._countdown_id_lock:
self._next_countdown_id += 1
return self._next_countdown_id
async def start_countdown(
self,
countdown: MultiplayerCountdown,
on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
):
async def _countdown_task(self: "ServerMultiplayerRoom"):
await asyncio.sleep(info.duration.total_seconds())
if on_complete is not None:
await on_complete(self)
await self.stop_countdown(countdown)
if countdown.is_exclusive:
await self.stop_all_countdowns(countdown.__class__)
countdown.id = await self.get_next_countdown_id()
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):
info = self._tracked_countdown.get(countdown.id)
if info is None:
return
del self._tracked_countdown[countdown.id]
self.room.active_countdowns.remove(countdown)
await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))
if info.task is not None and not info.task.done():
info.task.cancel()
async def stop_all_countdowns(self, typ: type[MultiplayerCountdown]):
for countdown in list(self._tracked_countdown.values()):
if isinstance(countdown.countdown, typ):
await self.stop_countdown(countdown.countdown)
class _MatchServerEvent(SignalRUnionMessage): ...
class CountdownStartedEvent(_MatchServerEvent):
countdown: MultiplayerCountdown
union_type: ClassVar[Literal[0]] = 0
class CountdownStoppedEvent(_MatchServerEvent):
id: int
union_type: ClassVar[Literal[1]] = 1
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
class GameplayAbortReason(IntEnum):
LOAD_TOOK_TOO_LONG = 0
HOST_ABORTED = 1
class MatchStartedEventDetail(TypedDict):
room_type: Literal["playlists", "head_to_head", "team_versus"]
team: dict[int, Literal["blue", "red"]] | None

22
app/models/playlist.py Normal file
View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from datetime import datetime
from app.models.mods import APIMod
from pydantic import BaseModel, Field
class PlaylistItem(BaseModel):
id: int = Field(default=0, ge=-1)
owner_id: int
beatmap_id: int
beatmap_checksum: str = ""
ruleset_id: int = 0
required_mods: list[APIMod] = Field(default_factory=list)
allowed_mods: list[APIMod] = Field(default_factory=list)
expired: bool = False
playlist_order: int = 0
played_at: datetime | None = None
star_rating: float = 0.0
freestyle: bool = False

View File

@@ -1,37 +1 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar
from pydantic import (
BaseModel,
Field,
)
@dataclass
class SignalRMeta:
member_ignore: bool = False # implement of IgnoreMember (msgpack) attribute
json_ignore: bool = False # implement of JsonIgnore (json) attribute
use_abbr: bool = True
class SignalRUnionMessage(BaseModel):
union_type: ClassVar[int]
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
class NegotiateResponse(BaseModel):
connectionId: str
connectionToken: str
negotiateVersion: int = 1
availableTransports: list[Transport]
class UserState(BaseModel):
connection_id: str
connection_token: str

View File

@@ -1,131 +0,0 @@
from __future__ import annotations
import datetime
from enum import IntEnum
from typing import Annotated, Any
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from .score import (
ScoreStatistics,
)
from .signalr import SignalRMeta, UserState
from pydantic import BaseModel, Field, field_validator
class SpectatedUserState(IntEnum):
Idle = 0
Playing = 1
Paused = 2
Passed = 3
Failed = 4
Quit = 5
class SpectatorState(BaseModel):
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
return False
return (
self.beatmap_id == other.beatmap_id
and self.ruleset_id == other.ruleset_id
and self.mods == other.mods
and self.state == other.state
)
class ScoreProcessorStatistics(BaseModel):
base_score: float
maximum_base_score: float
accuracy_judgement_count: int
combo_portion: float
bonus_portion: float
class FrameHeader(BaseModel):
total_score: int
accuracy: float
combo: int
max_combo: int
statistics: ScoreStatistics = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@field_validator("received_time", mode="before")
@classmethod
def validate_timestamp(cls, v: Any) -> datetime.datetime:
if isinstance(v, list):
return v[0]
if isinstance(v, datetime.datetime):
return v
if isinstance(v, int | float):
return datetime.datetime.fromtimestamp(v, tz=datetime.UTC)
if isinstance(v, str):
return datetime.datetime.fromisoformat(v)
raise ValueError(f"Cannot convert {type(v)} to datetime")
# class ReplayButtonState(IntEnum):
# NONE = 0
# LEFT1 = 1
# RIGHT1 = 2
# LEFT2 = 4
# RIGHT2 = 8
# SMOKE = 16
class LegacyReplayFrame(BaseModel):
time: float # from ReplayFrame,the parent of LegacyReplayFrame
mouse_x: float | None = None
mouse_y: float | None = None
button_state: int
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
class FrameDataBundle(BaseModel):
header: FrameHeader
frames: list[LegacyReplayFrame]
# Use for server
class APIUser(BaseModel):
id: int
name: str
class ScoreInfo(BaseModel):
mods: list[APIMod]
user: APIUser
ruleset: int
maximum_statistics: ScoreStatistics
id: int | None = None
total_score: int | None = None
accuracy: float | None = None
max_combo: int | None = None
combo: int | None = None
statistics: ScoreStatistics = Field(default_factory=dict)
class StoreScore(BaseModel):
score_info: ScoreInfo
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
class StoreClientState(UserState):
state: SpectatorState | None = None
beatmap_status: BeatmapRankStatus | None = None
checksum: str | None = None
ruleset_id: int | None = None
score_token: int | None = None
watched_user: set[int] = Field(default_factory=set)
score: StoreScore | None = None

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
# from app.signalr import signalr_router as signalr_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
from .file import file_router as file_router
@@ -25,5 +24,4 @@ __all__ = [
"private_router",
"redirect_api_router",
"redirect_router",
# "signalr_router",
]

View File

@@ -15,7 +15,7 @@ from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.storage import StorageService
from app.log import log
from app.models.multiplayer_hub import PlaylistItem as HubPlaylistItem
from app.models.playlist import PlaylistItem
from app.models.room import MatchType, QueueMode, RoomCategory, RoomStatus
from app.utils import utcnow
@@ -216,7 +216,7 @@ async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, A
# Insert playlist items
for item_data in items_raw:
hub_item = HubPlaylistItem(
playlist_item = PlaylistItem(
id=-1, # Placeholder, will be assigned by add_to_db
owner_id=item_data["owner_id"],
ruleset_id=item_data["ruleset_id"],
@@ -230,7 +230,7 @@ async def _add_playlist_items(db: Database, room_id: int, room_data: dict[str, A
beatmap_checksum=item_data["beatmap_checksum"],
star_rating=item_data["star_rating"],
)
await DBPlaylist.add_to_db(hub_item, room_id=room_id, session=db)
await DBPlaylist.add_to_db(playlist_item, room_id=room_id, session=db)
async def _add_host_as_participant(db: Database, room_id: int, host_user_id: int) -> None:

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from datetime import timedelta
from math import ceil
import random
import shlex
@@ -10,27 +9,15 @@ import shlex
from app.calculator import calculate_weighted_pp
from app.const import BANCHOBOT_ID
from app.database import ChatMessageResp
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
from app.database.score import Score, get_best_id
from app.database.statistics import UserStatistics, get_rank
from app.database.user import User
from app.dependencies.fetcher import get_fetcher
from app.exception import InvokeException
from app.models.mods import APIMod, get_available_mods, mod_to_save
from app.models.multiplayer_hub import (
ChangeTeamRequest,
ServerMultiplayerRoom,
StartMatchCountdownRequest,
)
from app.models.room import MatchType, QueueMode, RoomStatus
from app.models.mods import mod_to_save
from app.models.score import GameMode
from app.signalr.hub import MultiplayerHubs
from app.signalr.hub.hub import Client
from .server import server
from httpx import HTTPError
from sqlalchemy.orm import joinedload
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -216,352 +203,6 @@ PP: {statistics.pp:.2f}
"""
async def _mp_name(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp name <name>"
name = args[0]
try:
settings = room.room.settings.model_copy()
settings.name = name
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return f"Room name has changed to {name}"
except InvokeException as e:
return e.message
async def _mp_set(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp set <teammode> [<queuemode>]"
teammode = {"0": MatchType.HEAD_TO_HEAD, "2": MatchType.TEAM_VERSUS}.get(args[0])
if not teammode:
return "Invalid teammode. Use 0 for Head-to-Head or 2 for Team Versus."
queuemode = (
{
"0": QueueMode.HOST_ONLY,
"1": QueueMode.ALL_PLAYERS,
"2": QueueMode.ALL_PLAYERS_ROUND_ROBIN,
}.get(args[1])
if len(args) >= 2
else None
)
try:
settings = room.room.settings.model_copy()
settings.match_type = teammode
if queuemode:
settings.queue_mode = queuemode
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return f"Room setting 'teammode' has been changed to {teammode.name.lower()}"
except InvokeException as e:
return e.message
async def _mp_host(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp host <username>"
username = args[0]
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
try:
await MultiplayerHubs.TransferHost(signalr_client, user_id)
return f"User '{username}' has been hosted in the room."
except InvokeException as e:
return e.message
async def _mp_start(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
timer = None
if len(args) >= 1 and args[0].isdigit():
timer = int(args[0])
try:
if timer is not None:
await MultiplayerHubs.SendMatchRequest(
signalr_client,
StartMatchCountdownRequest(duration=timedelta(seconds=timer)),
)
return ""
else:
await MultiplayerHubs.StartMatch(signalr_client)
return "Good luck! Enjoy game!"
except InvokeException as e:
return e.message
async def _mp_abort(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
try:
await MultiplayerHubs.AbortMatch(signalr_client)
return "Match aborted."
except InvokeException as e:
return e.message
async def _mp_team(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
):
if room.room.settings.match_type != MatchType.TEAM_VERSUS:
return "This command is only available in Team Versus mode."
if len(args) < 2:
return "Usage: !mp team <username> <colour>"
username = args[0]
team = {"red": 0, "blue": 1}.get(args[1])
if team is None:
return "Invalid team colour. Use 'red' or 'blue'."
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
if not user_client:
return f"User '{username}' is not in the room."
assert room.room.host
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
return "You are not allowed to change other users' teams."
try:
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
return ""
except InvokeException as e:
return e.message
async def _mp_password(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
password = ""
if len(args) >= 1:
password = args[0]
try:
settings = room.room.settings.model_copy()
settings.password = password
await MultiplayerHubs.ChangeSettings(signalr_client, settings)
return "Room password has been set."
except InvokeException as e:
return e.message
async def _mp_kick(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp kick <username>"
username = args[0]
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
try:
await MultiplayerHubs.KickUser(signalr_client, user_id)
return f"User '{username}' has been kicked from the room."
except InvokeException as e:
return e.message
async def _mp_map(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp map <mapid> [<playmode>]"
if room.status != RoomStatus.IDLE:
return "Cannot change map while the game is running."
map_id = args[0]
if not map_id.isdigit():
return "Invalid map ID."
map_id = int(map_id)
playmode = GameMode.parse(args[1].upper()) if len(args) >= 2 else None
if playmode not in (
GameMode.OSU,
GameMode.TAIKO,
GameMode.FRUITS,
GameMode.MANIA,
None,
):
return "Invalid playmode."
try:
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
except HTTPError:
return "Beatmap not found"
try:
current_item = room.queue.current_item
item = current_item.model_copy(deep=True)
item.owner_id = signalr_client.user_id
item.beatmap_checksum = beatmap.checksum
item.required_mods = []
item.allowed_mods = []
item.freestyle = False
item.beatmap_id = map_id
if playmode is not None:
item.ruleset_id = int(playmode)
if item.expired:
item.id = 0
item.expired = False
item.played_at = None
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
else:
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
return ""
except InvokeException as e:
return e.message
async def _mp_mods(
signalr_client: Client,
room: ServerMultiplayerRoom,
args: list[str],
session: AsyncSession,
) -> str:
if len(args) < 1:
return "Usage: !mp mods <mod1> [<mod2> ...]"
if room.status != RoomStatus.IDLE:
return "Cannot change mods while the game is running."
required_mods = []
allowed_mods = []
freestyle = False
freemod = False
for arg in args:
arg = arg.upper()
if arg == "NONE":
required_mods.clear()
allowed_mods.clear()
break
elif arg == "FREESTYLE":
freestyle = True
elif arg == "FREEMOD":
freemod = True
elif arg.startswith("+"):
mod = arg.removeprefix("+")
if len(mod) != 2:
return f"Invalid mod: {mod}."
allowed_mods.append(APIMod(acronym=mod))
else:
if len(arg) != 2:
return f"Invalid mod: {arg}."
required_mods.append(APIMod(acronym=arg))
try:
current_item = room.queue.current_item
item = current_item.model_copy(deep=True)
item.owner_id = signalr_client.user_id
item.freestyle = freestyle
if freestyle:
item.allowed_mods = []
elif freemod:
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
else:
item.allowed_mods = allowed_mods
item.required_mods = required_mods
if item.expired:
item.id = 0
item.expired = False
item.played_at = None
await MultiplayerHubs.AddPlaylistItem(signalr_client, item)
else:
await MultiplayerHubs.EditPlaylistItem(signalr_client, item)
return ""
except InvokeException as e:
return e.message
_MP_COMMANDS = {
"name": _mp_name,
"set": _mp_set,
"host": _mp_host,
"start": _mp_start,
"abort": _mp_abort,
"map": _mp_map,
"mods": _mp_mods,
"kick": _mp_kick,
"password": _mp_password,
"team": _mp_team,
}
_MP_HELP = """!mp name <name>
!mp set <teammode> [<queuemode>]
!mp host <host>
!mp start [<timer>]
!mp abort
!mp map <map> [<playmode>]
!mp mods <mod1> [<mod2> ...]
!mp kick <user>
!mp password [<password>]
!mp team <user> <team:red|blue>"""
@bot.command("mp")
async def _mp(user: User, args: list[str], session: AsyncSession, channel: ChatChannel):
if not channel.name.startswith("room_"):
return
room_id = int(channel.name[5:])
room = MultiplayerHubs.rooms.get(room_id)
if not room:
return
signalr_client = MultiplayerHubs.get_client_by_id(str(user.id))
if not signalr_client:
return
if len(args) < 1:
return f"Usage: !mp <{'|'.join(_MP_COMMANDS.keys())}> [args]"
command = args[0].lower()
if command not in _MP_COMMANDS:
return f"No such command: {command}"
return await _MP_COMMANDS[command](signalr_client, room, args[1:], session)
async def _score(
user_id: int,
session: AsyncSession,

View File

@@ -16,7 +16,6 @@ from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser, get_current_user
from app.models.room import RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
from app.signalr.hub import MultiplayerHubs
from app.utils import utcnow
from .router import router
@@ -391,14 +390,12 @@ async def get_room_events(
first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id)
if room := MultiplayerHubs.rooms.get(room_id):
current_playlist_item_id = room.queue.current_item.id
room_resp = await RoomResp.from_hub(room)
else:
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
room_resp = await RoomResp.from_db(room, db)
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
room_resp = await RoomResp.from_db(room, db)
if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
current_playlist_item_id = room_resp.current_playlist_item.id
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]

View File

@@ -217,8 +217,7 @@ class MessageQueueProcessor:
):
"""通知客户端消息ID已更新"""
try:
# 这里我们需要通过 SignalR消息更新通知
# 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件
# 通过 Redis消息更新事件,由聊天通知服务分发到客户端
update_event = {
"event": "chat.message.update",
"data": {
@@ -229,7 +228,6 @@ class MessageQueueProcessor:
},
}
# 发布到 Redis 频道,让 SignalR 服务处理
await self._redis_exec(
self.redis_message.publish,
f"chat_updates:{channel_id}",

View File

@@ -1,85 +0,0 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING
from app.database import PlaylistBestScore, Score
from app.database.playlist_best_score import get_position
from app.dependencies.database import with_db
from app.models.metadata_hub import MultiplayerRoomScoreSetEvent
from .base import RedisSubscriber
from sqlmodel import select
if TYPE_CHECKING:
from app.signalr.hub import MetadataHub
CHANNEL = "osu-channel:score:processed"
class ScoreSubscriber(RedisSubscriber):
def __init__(self):
super().__init__()
self.room_subscriber: dict[int, list[int]] = {}
self.metadata_hub: "MetadataHub | None " = None
self.subscribed = False
self.handlers[CHANNEL] = [self._handler]
async def subscribe_room_score(self, room_id: int, user_id: int):
if room_id not in self.room_subscriber:
await self.subscribe(CHANNEL)
self.start()
self.room_subscriber.setdefault(room_id, []).append(user_id)
async def unsubscribe_room_score(self, room_id: int, user_id: int):
if room_id in self.room_subscriber:
try:
self.room_subscriber[room_id].remove(user_id)
except ValueError:
pass
if not self.room_subscriber[room_id]:
del self.room_subscriber[room_id]
async def _notify_room_score_processed(self, score_id: int):
if not self.metadata_hub:
return
async with with_db() as session:
score = await session.get(Score, score_id)
if not score or not score.passed or score.room_id is None or score.playlist_item_id is None:
return
if not self.room_subscriber.get(score.room_id, []):
return
new_rank = None
user_best = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.user_id == score.user_id,
PlaylistBestScore.room_id == score.room_id,
)
)
).first()
if user_best and user_best.score_id == score_id:
new_rank = await get_position(
user_best.room_id,
user_best.playlist_id,
user_best.score_id,
session,
)
event = MultiplayerRoomScoreSetEvent(
room_id=score.room_id,
playlist_item_id=score.playlist_item_id,
score_id=score_id,
user_id=score.user_id,
total_score=score.total_score,
new_rank=new_rank,
)
await self.metadata_hub.notify_room_score_processed(event)
async def _handler(self, channel: str, data: str):
score_id = json.loads(data)["ScoreId"]
if self.metadata_hub:
await self._notify_room_score_processed(score_id)

View File

@@ -1,5 +0,0 @@
from __future__ import annotations
from .router import router as signalr_router
__all__ = ["signalr_router"]

View File

@@ -1,15 +0,0 @@
from __future__ import annotations
from .hub import Hub
from .metadata import MetadataHub
from .multiplayer import MultiplayerHub
from .spectator import SpectatorHub
SpectatorHubs = SpectatorHub()
MultiplayerHubs = MultiplayerHub()
MetadataHubs = MetadataHub()
Hubs: dict[str, Hub] = {
"spectator": SpectatorHubs,
"multiplayer": MultiplayerHubs,
"metadata": MetadataHubs,
}

View File

@@ -1,322 +0,0 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
import time
from typing import Any
from app.config import settings
from app.exception import InvokeException
from app.log import logger
from app.models.signalr import UserState
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
InvocationPacket,
Packet,
PingPacket,
Protocol,
)
from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
from starlette.websockets import WebSocketDisconnect
class CloseConnection(Exception):
def __init__(
self,
message: str = "Connection closed",
allow_reconnect: bool = False,
from_client: bool = False,
) -> None:
super().__init__(message)
self.message = message
self.allow_reconnect = allow_reconnect
self.from_client = from_client
class Client:
def __init__(
self,
connection_id: str,
connection_token: str,
connection: WebSocket,
protocol: Protocol,
) -> None:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self.protocol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
def __hash__(self) -> int:
return hash(self.connection_token)
@property
def user_id(self) -> int:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.protocol.encode(packet))
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return []
return self.protocol.decode(d)
async def _ping(self):
while True:
try:
await self.send_packet(PingPacket())
await asyncio.sleep(settings.signalr_ping_interval)
except WebSocketDisconnect:
break
except RuntimeError as e:
if "disconnect message" in str(e) or "close message" in str(e):
break
else:
logger.error(f"Error in ping task for {self.connection_id}: {e}")
break
except Exception:
logger.exception(f"Error in client {self.connection_id}")
class Hub[TState: UserState]:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
self.groups: dict[str, set[Client]] = {}
self.state: dict[int, TState] = {}
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
def get_client_by_id(self, id: str, default: Any = None) -> Client:
for client in self.clients.values():
if client.connection_id == id:
return client
return default
def get_before_clients(self, id: str, current_token: str) -> list[Client]:
clients = []
for client in self.clients.values():
if client.connection_id != id:
continue
if client.connection_token == current_token:
continue
clients.append(client)
return clients
@abstractmethod
def create_state(self, client: Client) -> TState:
raise NotImplementedError
def get_or_create_state(self, client: Client) -> TState:
if (state := self.state.get(client.user_id)) is not None:
return state
state = self.create_state(client)
self.state[client.user_id] = state
return state
def add_to_group(self, client: Client, group_id: str) -> None:
self.groups.setdefault(group_id, set()).add(client)
def remove_from_group(self, client: Client, group_id: str) -> None:
if group_id in self.groups:
self.groups[group_id].discard(client)
async def kick_client(self, client: Client) -> None:
await self.call_noblock(client, "DisconnectRequested")
await client.send_packet(ClosePacket(allow_reconnect=False))
await client.connection.close(code=1000, reason="Disconnected by server")
async def add_client(
self,
connection_id: str,
connection_token: str,
protocol: Protocol,
connection: WebSocket,
) -> Client:
if connection_token in self.clients:
raise ValueError(f"Client with connection token {connection_token} already exists.")
if connection_token in self.waited_clients:
if self.waited_clients[connection_token] < time.time() - settings.signalr_negotiate_timeout:
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection, protocol)
self.clients[connection_token] = client
task = asyncio.create_task(client._ping())
self.tasks.add(task)
client._ping_task = task
return client
async def remove_client(self, client: Client) -> None:
if client.connection_token not in self.clients:
return
del self.clients[client.connection_token]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
for group in self.groups.values():
group.discard(client)
await self.clean_state(client, False)
@abstractmethod
async def _clean_state(self, state: TState) -> None:
return
async def clean_state(self, client: Client, disconnected: bool) -> None:
if (state := self.state.get(client.user_id)) is None:
return
if disconnected and client.connection_token != state.connection_token:
return
try:
await self._clean_state(state)
del self.state[client.user_id]
except Exception:
...
async def on_connect(self, client: Client) -> None:
if method := getattr(self, "on_client_connect", None):
await method(client)
async def send_packet(self, client: Client, packet: Packet) -> None:
logger.trace(f"[SignalR] send to {client.connection_id} packet {packet}")
try:
await client.send_packet(packet)
except WebSocketDisconnect as e:
logger.info(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
await self.remove_client(client)
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"Client {client.connection_id} closed the connection.")
else:
logger.exception(f"RuntimeError in client {client.connection_id}: {e}")
await self.remove_client(client)
except Exception:
logger.exception(f"Error in client {client.connection_id}")
await self.remove_client(client)
async def broadcast_call(self, method: str, *args: Any) -> None:
tasks = []
for client in self.clients.values():
tasks.append(self.call_noblock(client, method, *args))
await asyncio.gather(*tasks)
async def broadcast_group_call(self, group_id: str, method: str, *args: Any) -> None:
tasks = []
for client in self.groups.get(group_id, []):
tasks.append(self.call_noblock(client, method, *args))
await asyncio.gather(*tasks)
async def _listen_client(self, client: Client) -> None:
try:
while True:
packets = await client.receive_packets()
for packet in packets:
if isinstance(packet, PingPacket):
continue
elif isinstance(packet, ClosePacket):
raise CloseConnection(
packet.error or "Connection closed by client",
packet.allow_reconnect,
True,
)
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
logger.info(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"Client {client.connection_id} closed the connection.")
else:
logger.exception(f"RuntimeError in client {client.connection_id}: {e}")
except CloseConnection as e:
if not e.from_client:
await client.send_packet(ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect))
logger.info(f"Client {client.connection_id} closed the connection: {e.message}")
except Exception:
logger.exception(f"Error in client {client.connection_id}")
await self.remove_client(client)
async def _handle_packet(self, client: Client, packet: Packet) -> None:
if isinstance(packet, PingPacket):
return
elif isinstance(packet, InvocationPacket):
args = packet.arguments or []
error = None
result = None
try:
result = await self.invoke_method(client, packet.target, args)
except InvokeException as e:
error = e.message
logger.debug(f"Client {client.connection_token} call {packet.target} failed: {error}")
except Exception:
logger.exception(f"Error invoking method {packet.target} for client {client.connection_id}")
error = "Unknown error occured in server"
if packet.invocation_id is not None:
await self.send_packet(
client,
CompletionPacket(
invocation_id=packet.invocation_id,
error=error,
result=result,
),
)
elif isinstance(packet, CompletionPacket):
client._store.add_result(packet.invocation_id, packet.result, packet.error)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None)
call_params = []
if not method_:
raise InvokeException(f"Method '{method}' not found in hub.")
signature = get_signature(method_)
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
call_params.append(client.protocol.validate_object(args.pop(0), param.annotation))
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id()
await self.send_packet(
client,
InvocationPacket(
header={},
invocation_id=invocation_id,
target=method,
arguments=list(args),
stream_ids=None,
),
)
r = await client._store.fetch(invocation_id, None)
if r[1]:
raise InvokeException(r[1])
return r[0]
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await self.send_packet(
client,
InvocationPacket(
header={},
invocation_id=None,
target=method,
arguments=list(args),
stream_ids=None,
),
)
return None
def __contains__(self, item: str) -> bool:
return item in self.clients or item in self.waited_clients

View File

@@ -1,296 +0,0 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Coroutine
import math
from typing import override
from app.calculator import clamp
from app.database import Relationship, RelationshipType, User
from app.database.playlist_best_score import PlaylistBestScore
from app.database.playlists import Playlist
from app.database.room import Room
from app.database.score import Score
from app.dependencies.database import with_db
from app.log import logger
from app.models.metadata_hub import (
TOTAL_SCORE_DISTRIBUTION_BINS,
DailyChallengeInfo,
MetadataClientState,
MultiplayerPlaylistItemStats,
MultiplayerRoomScoreSetEvent,
MultiplayerRoomStats,
OnlineStatus,
UserActivity,
)
from app.models.room import RoomCategory
from app.service.subscribers.score_processed import ScoreSubscriber
from app.utils import utcnow
from .hub import Client, Hub
from sqlmodel import col, select
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
class MetadataHub(Hub[MetadataClientState]):
def __init__(self) -> None:
super().__init__()
self.subscriber = ScoreSubscriber()
self.subscriber.metadata_hub = self
self._daily_challenge_stats: MultiplayerRoomStats | None = None
self._today = utcnow().date()
self._lock = asyncio.Lock()
def get_daily_challenge_stats(self, daily_challenge_room: int) -> MultiplayerRoomStats:
if self._daily_challenge_stats is None or self._today != utcnow().date():
self._daily_challenge_stats = MultiplayerRoomStats(
room_id=daily_challenge_room,
playlist_item_stats={},
)
return self._daily_challenge_stats
@staticmethod
def online_presence_watchers_group() -> str:
return ONLINE_PRESENCE_WATCHERS_GROUP
@staticmethod
def room_watcher_group(room_id: int) -> str:
return f"metadata:multiplayer-room-watchers:{room_id}"
def broadcast_tasks(self, user_id: int, store: MetadataClientState | None) -> set[Coroutine]:
if store is not None and not store.pushable:
return set()
data = store.for_push if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
"UserPresenceUpdated",
user_id,
data,
),
self.broadcast_group_call(
self.friend_presence_watchers_group(user_id),
"FriendPresenceUpdated",
user_id,
data,
),
}
@staticmethod
def friend_presence_watchers_group(user_id: int):
return f"metadata:friend-presence-watchers:{user_id}"
@override
async def _clean_state(self, state: MetadataClientState) -> None:
user_id = int(state.connection_id)
if state.pushable:
await asyncio.gather(*self.broadcast_tasks(user_id, None))
async with with_db() as session:
async with session.begin():
user = (await session.exec(select(User).where(User.id == int(state.connection_id)))).one()
user.last_visit = utcnow()
await session.commit()
@override
def create_state(self, client: Client) -> MetadataClientState:
return MetadataClientState(
connection_id=client.connection_id,
connection_token=client.connection_token,
)
async def on_client_connect(self, client: Client) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
# CRITICAL FIX: Set online status IMMEDIATELY upon connection
# This matches the C# official implementation behavior
store.status = OnlineStatus.ONLINE
logger.info(f"[MetadataHub] Set user {user_id} status to ONLINE upon connection")
async with with_db() as session:
async with session.begin():
friends = (
await session.exec(
select(Relationship.target_id).where(
Relationship.user_id == user_id,
Relationship.type == RelationshipType.FOLLOW,
)
)
).all()
tasks = []
for friend_id in friends:
self.groups.setdefault(self.friend_presence_watchers_group(friend_id), set()).add(client)
if (friend_state := self.state.get(friend_id)) and friend_state.pushable:
tasks.append(
self.broadcast_group_call(
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
friend_state.for_push if friend_state.pushable else None,
)
)
await asyncio.gather(*tasks)
daily_challenge_room = (
await session.exec(
select(Room).where(
col(Room.ends_at) > utcnow(),
Room.category == RoomCategory.DAILY_CHALLENGE,
)
)
).first()
if daily_challenge_room:
await self.call_noblock(
client,
"DailyChallengeUpdated",
DailyChallengeInfo(
room_id=daily_challenge_room.id,
),
)
# CRITICAL FIX: Immediately broadcast the user's online status to all watchers
# This ensures the user appears as "currently online" right after connection
# Similar to the C# implementation's immediate broadcast logic
online_presence_tasks = self.broadcast_tasks(user_id, store)
if online_presence_tasks:
await asyncio.gather(*online_presence_tasks)
logger.info(f"[MetadataHub] Broadcasted online status for user {user_id} to watchers")
# Also send the user's own presence update to confirm online status
await self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.for_push,
)
logger.info(f"[MetadataHub] User {user_id} is now ONLINE and visible to other clients")
async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status)
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
if store.status is not None and store.status == status_:
return
store.status = OnlineStatus(status_)
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.for_push,
)
)
await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity: UserActivity | None) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
store.activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.for_push,
)
)
await asyncio.gather(*tasks)
async def BeginWatchingUserPresence(self, client: Client) -> None:
# Critical fix: Send all currently online users to the new watcher
# Must use for_push to get the correct UserPresence format
await asyncio.gather(
*[
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.for_push, # Fixed: use for_push instead of store
)
for user_id, store in self.state.items()
if store.pushable
]
)
self.add_to_group(client, self.online_presence_watchers_group())
logger.info(
f"[MetadataHub] Client {client.connection_id} now watching user presence, "
f"sent {len([s for s in self.state.values() if s.pushable])} online users"
)
async def EndWatchingUserPresence(self, client: Client) -> None:
self.remove_from_group(client, self.online_presence_watchers_group())
async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent):
await self.broadcast_group_call(self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event)
async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int):
self.add_to_group(client, self.room_watcher_group(room_id))
await self.subscriber.subscribe_room_score(room_id, client.user_id)
stats = self.get_daily_challenge_stats(room_id)
await self.update_daily_challenge_stats(stats)
return list(stats.playlist_item_stats.values())
async def update_daily_challenge_stats(self, stats: MultiplayerRoomStats) -> None:
async with with_db() as session:
playlist_ids = (
await session.exec(
select(Playlist.id).where(
Playlist.room_id == stats.room_id,
)
)
).all()
for playlist_id in playlist_ids:
item = stats.playlist_item_stats.get(playlist_id, None)
if item is None:
item = MultiplayerPlaylistItemStats(
playlist_item_id=playlist_id,
total_score_distribution=[0] * TOTAL_SCORE_DISTRIBUTION_BINS,
cumulative_score=0,
last_processed_score_id=0,
)
stats.playlist_item_stats[playlist_id] = item
last_processed_score_id = item.last_processed_score_id
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == stats.room_id,
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.score_id > last_processed_score_id,
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)),
)
)
).all()
if len(scores) == 0:
continue
async with self._lock:
if item.last_processed_score_id == last_processed_score_id:
totals = defaultdict(int)
for score in scores:
bin_index = int(
clamp(
math.floor(score.total_score / 100000),
0,
TOTAL_SCORE_DISTRIBUTION_BINS - 1,
)
)
totals[bin_index] += 1
item.cumulative_score += sum(score.total_score for score in scores)
for j in range(TOTAL_SCORE_DISTRIBUTION_BINS):
item.total_score_distribution[j] += totals.get(j, 0)
if scores:
item.last_processed_score_id = max(score.score_id for score in scores)
async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int):
self.remove_from_group(client, self.room_watcher_group(room_id))
await self.subscriber.unsubscribe_room_score(room_id, client.user_id)

File diff suppressed because it is too large Load Diff

View File

@@ -1,585 +0,0 @@
from __future__ import annotations
import asyncio
import json
import lzma
import struct
import time
from typing import override
from app.calculator import clamp
from app.config import settings
from app.database import Beatmap, User
from app.database.failtime import FailTime, FailTimeResp
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.database.statistics import UserStatistics
from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.exception import InvokeException
from app.log import logger
from app.models.mods import APIMod, mods_to_int
from app.models.score import GameMode, LegacyReplaySoloScoreInfo, ScoreStatistics
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
LegacyReplayFrame,
ScoreInfo,
SpectatedUserState,
SpectatorState,
StoreClientState,
StoreScore,
)
from app.utils import unix_timestamp_to_windows
from .hub import Client, Hub
from httpx import HTTPError
from sqlalchemy.orm import joinedload
from sqlmodel import select
READ_SCORE_TIMEOUT = 30
REPLAY_LATEST_VER = 30000016
def encode_uleb128(num: int) -> bytes | bytearray:
if num == 0:
return b"\x00"
ret = bytearray()
while num != 0:
ret.append(num & 0x7F)
num >>= 7
if num != 0:
ret[-1] |= 0x80
return ret
def encode_string(s: str) -> bytes:
"""Write `s` into bytes (ULEB128 & string)."""
if s:
encoded = s.encode()
ret = b"\x0b" + encode_uleb128(len(encoded)) + encoded
else:
ret = b"\x00"
return ret
async def save_replay(
ruleset_id: int,
md5: str,
username: str,
score: Score,
statistics: ScoreStatistics,
maximum_statistics: ScoreStatistics,
frames: list[LegacyReplayFrame],
) -> None:
data = bytearray()
data.extend(struct.pack("<bi", ruleset_id, REPLAY_LATEST_VER))
data.extend(encode_string(md5))
data.extend(encode_string(username))
data.extend(encode_string(f"lazer-{username}-{score.started_at.isoformat()}"))
data.extend(
struct.pack(
"<hhhhhhihbi",
score.n300,
score.n100,
score.n50,
score.ngeki,
score.nkatu,
score.nmiss,
score.total_score,
score.max_combo,
score.is_perfect_combo,
mods_to_int(score.mods),
)
)
data.extend(encode_string("")) # hp graph
data.extend(
struct.pack(
"<q",
unix_timestamp_to_windows(round(score.started_at.timestamp())),
)
)
# write frames
frame_strs = []
last_time = 0
for frame in frames:
time = round(frame.time)
frame_strs.append(f"{time - last_time}|{frame.mouse_x or 0.0}|{frame.mouse_y or 0.0}|{frame.button_state}")
last_time = time
frame_strs.append("-12345|0|0|0")
compressed = lzma.compress(",".join(frame_strs).encode("ascii"), format=lzma.FORMAT_ALONE)
data.extend(struct.pack("<i", len(compressed)))
data.extend(compressed)
data.extend(struct.pack("<q", score.id))
score_info = LegacyReplaySoloScoreInfo(
online_id=score.id,
mods=score.mods,
statistics=statistics,
maximum_statistics=maximum_statistics,
client_version="",
rank=score.rank,
user_id=score.user_id,
total_score_without_mods=score.total_score_without_mods,
)
compressed = lzma.compress(json.dumps(score_info).encode(), format=lzma.FORMAT_ALONE)
data.extend(struct.pack("<i", len(compressed)))
data.extend(compressed)
storage_service = get_storage_service()
replay_path = score.replay_filename
await storage_service.write_file(replay_path, bytes(data), "application/x-osu-replay")
class SpectatorHub(Hub[StoreClientState]):
@staticmethod
def group_id(user_id: int) -> str:
return f"watch:{user_id}"
@override
def create_state(self, client: Client) -> StoreClientState:
return StoreClientState(
connection_id=client.connection_id,
connection_token=client.connection_token,
)
@override
async def _clean_state(self, state: StoreClientState) -> None:
"""
Enhanced cleanup based on official osu-server-spectator implementation.
Properly notifies watched users when spectator disconnects.
"""
user_id = int(state.connection_id)
if state.state:
await self._end_session(user_id, state.state, state)
# Critical fix: Notify all watched users that this spectator has disconnected
# This matches the official CleanUpState implementation
for watched_user_id in state.watched_user:
if (target_client := self.get_client_by_id(str(watched_user_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)
logger.debug(f"[SpectatorHub] Notified {watched_user_id} that {user_id} stopped watching")
async def on_client_connect(self, client: Client) -> None:
"""
Enhanced connection handling based on official implementation.
Send all active player states to newly connected clients.
"""
logger.info(f"[SpectatorHub] Client {client.user_id} connected")
# Send all current player states to the new client
# This matches the official OnConnectedAsync behavior
active_states = []
for user_id, store in self.state.items():
if store.state is not None:
active_states.append((user_id, store.state))
if active_states:
logger.debug(f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}")
# Send states sequentially to avoid overwhelming the client
for user_id, state in active_states:
try:
await self.call_noblock(client, "UserBeganPlaying", user_id, state)
except Exception as e:
logger.debug(f"[SpectatorHub] Failed to send state for user {user_id}: {e}")
# Also sync with MultiplayerHub for cross-hub spectating
await self._sync_with_multiplayer_hub(client)
async def _sync_with_multiplayer_hub(self, client: Client) -> None:
"""
Sync with MultiplayerHub to get active multiplayer game states.
This ensures spectators can see multiplayer games from other pages.
"""
try:
# Import here to avoid circular imports
from app.signalr.hub import MultiplayerHubs
# Check all active multiplayer rooms for playing users
for room_id, server_room in MultiplayerHubs.rooms.items():
for room_user in server_room.room.users:
# Send state for users who are playing or in results
if room_user.state.is_playing and room_user.user_id not in self.state:
# Create a synthetic SpectatorState for multiplayer players
# This helps with cross-hub spectating
try:
synthetic_state = SpectatorState(
beatmap_id=server_room.queue.current_item.beatmap_id,
ruleset_id=room_user.ruleset_id or 0, # Default to osu!
mods=room_user.mods,
state=SpectatedUserState.Playing,
maximum_statistics={},
)
await self.call_noblock(
client,
"UserBeganPlaying",
room_user.user_id,
synthetic_state,
)
logger.debug(
f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}"
)
except Exception as e:
logger.debug(f"[SpectatorHub] Failed to create synthetic state: {e}")
# Critical addition: Notify about finished players in multiplayer games
elif (
hasattr(room_user.state, "name")
and room_user.state.name == "RESULTS"
and room_user.user_id not in self.state
):
try:
# Create a synthetic finished state
finished_state = SpectatorState(
beatmap_id=server_room.queue.current_item.beatmap_id,
ruleset_id=room_user.ruleset_id or 0,
mods=room_user.mods,
state=SpectatedUserState.Passed, # Assume passed for results
maximum_statistics={},
)
await self.call_noblock(
client,
"UserFinishedPlaying",
room_user.user_id,
finished_state,
)
logger.debug(f"[SpectatorHub] Sent synthetic finished state for user {room_user.user_id}")
except Exception as e:
logger.debug(f"[SpectatorHub] Failed to create synthetic finished state: {e}")
except Exception as e:
logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}")
# This is not critical, so we don't raise the exception
async def BeginPlaySession(self, client: Client, score_token: int, state: SpectatorState) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
if store.state is not None:
logger.warning(f"[SpectatorHub] User {user_id} began new session without ending previous one; cleaning up")
try:
await self._end_session(user_id, store.state, store)
finally:
store.state = None
store.beatmap_status = None
store.checksum = None
store.ruleset_id = None
store.score_token = None
store.score = None
if state.beatmap_id is None or state.ruleset_id is None:
return
fetcher = await get_fetcher()
async with with_db() as session:
async with session.begin():
try:
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=state.beatmap_id)
except HTTPError:
raise InvokeException(f"Beatmap {state.beatmap_id} not found.")
user = (await session.exec(select(User).where(User.id == user_id))).first()
if not user:
return
name = user.username
store.state = state
store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum
store.ruleset_id = state.ruleset_id
store.score_token = score_token
store.score = StoreScore(
score_info=ScoreInfo(
mods=state.mods,
user=APIUser(id=user_id, name=name),
ruleset=state.ruleset_id,
maximum_statistics=state.maximum_statistics,
)
)
logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}")
await self.broadcast_group_call(
self.group_id(user_id),
"UserBeganPlaying",
user_id,
state,
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
if store.state is None or store.score is None:
return
header = frame_data.header
score_info = store.score.score_info
score_info.accuracy = header.accuracy
score_info.combo = header.combo
score_info.max_combo = header.max_combo
score_info.statistics = header.statistics
store.score.replay_frames.extend(frame_data.frames)
await self.broadcast_group_call(self.group_id(user_id), "UserSentFrames", user_id, frame_data)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
score = store.score
# Early return if no active session
if (
score is None
or store.score_token is None
or store.beatmap_status is None
or store.state is None
or store.score is None
):
return
try:
# Process score if conditions are met
if (settings.enable_all_beatmap_leaderboard and store.beatmap_status.has_leaderboard()) and any(
k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()
):
await self._process_score(store, client)
# End the play session and notify watchers
await self._end_session(user_id, state, store)
finally:
# CRITICAL FIX: Always clear state in finally block to ensure cleanup
# This matches the official C# implementation pattern
store.state = None
store.beatmap_status = None
store.checksum = None
store.ruleset_id = None
store.score_token = None
store.score = None
logger.info(f"[SpectatorHub] Cleared all session state for user {user_id}")
async def _process_score(self, store: StoreClientState, client: Client) -> None:
user_id = int(client.connection_id)
assert store.state is not None
assert store.score_token is not None
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.score is not None
async with with_db() as session:
async with session:
start_time = time.time()
score_record = None
while time.time() - start_time < READ_SCORE_TIMEOUT:
sub_query = select(ScoreToken.score_id).where(
ScoreToken.id == store.score_token,
)
result = await session.exec(
select(Score)
.options(joinedload(Score.beatmap))
.where(
Score.id == sub_query.scalar_subquery(),
Score.user_id == user_id,
)
)
score_record = result.first()
if score_record:
break
if not score_record:
return
if not score_record.passed:
return
await self.call_noblock(
client,
"UserScoreProcessed",
user_id,
score_record.id,
)
# save replay
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
await save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=store.score.score_info.statistics,
maximum_statistics=store.score.score_info.maximum_statistics,
frames=store.score.replay_frames,
)
async def _end_session(self, user_id: int, state: SpectatorState, store: StoreClientState) -> None:
async def _add_failtime():
async with with_db() as session:
failtime = await session.get(FailTime, state.beatmap_id)
total_length = (
await session.exec(select(Beatmap.total_length).where(Beatmap.id == state.beatmap_id))
).one()
index = clamp(round((exit_time / total_length) * 100), 0, 99)
if failtime is not None:
resp = FailTimeResp.from_db(failtime)
else:
resp = FailTimeResp()
if state.state == SpectatedUserState.Failed:
resp.fail[index] += 1
elif state.state == SpectatedUserState.Quit:
resp.exit[index] += 1
assert state.beatmap_id
new_failtime = FailTime.from_resp(state.beatmap_id, resp)
if failtime is not None:
await session.merge(new_failtime)
else:
session.add(new_failtime)
await session.commit()
async def _edit_playtime(token: int, ruleset_id: int, mods: list[APIMod]):
redis = get_redis()
key = f"score:existed_time:{token}"
messages = await redis.xrange(key, min="-", max="+", count=1)
if not messages:
return
before_time = int(messages[0][1]["time"])
await redis.delete(key)
async with with_db() as session:
gamemode = GameMode.from_int(ruleset_id).to_special_mode(mods)
statistics = (
await session.exec(
select(UserStatistics).where(
UserStatistics.user_id == user_id,
UserStatistics.mode == gamemode,
)
)
).first()
if statistics is None:
return
statistics.play_time -= before_time
statistics.play_time += round(min(before_time, exit_time))
if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit
logger.debug(f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}")
# Calculate exit time safely
exit_time = 0
if store.score and store.score.replay_frames:
exit_time = max(frame.time for frame in store.score.replay_frames) // 1000
# Background task for playtime editing - only if we have valid data
if store.score_token and store.ruleset_id and store.score:
task = asyncio.create_task(
_edit_playtime(
store.score_token,
store.ruleset_id,
store.score.score_info.mods,
)
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
# Background task for failtime tracking - only for failed/quit states with valid data
if (
state.beatmap_id is not None
and exit_time > 0
and state.state in (SpectatedUserState.Failed, SpectatedUserState.Quit)
):
task = asyncio.create_task(_add_failtime())
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
logger.info(f"[SpectatorHub] {user_id} finished playing {state.beatmap_id} with {state.state}")
await self.broadcast_group_call(
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
state,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
"""
Enhanced StartWatchingUser based on official osu-server-spectator implementation.
Properly handles state synchronization and watcher notifications.
"""
user_id = int(client.connection_id)
logger.info(f"[SpectatorHub] {user_id} started watching {target_id}")
try:
# Get target user's current state if it exists
target_store = self.state.get(target_id)
if not target_store or not target_store.state:
logger.info(f"[SpectatorHub] Rejecting watch request for {target_id}: user not playing")
raise InvokeException("Target user is not currently playing")
if target_store.state.state != SpectatedUserState.Playing:
logger.info(
f"[SpectatorHub] Rejecting watch request for {target_id}: state is {target_store.state.state}"
)
raise InvokeException("Target user is not currently playing")
logger.debug(f"[SpectatorHub] {target_id} is currently playing, sending state")
# Send current state to the watcher immediately
await self.call_noblock(
client,
"UserBeganPlaying",
target_id,
target_store.state,
)
except InvokeException:
# Re-raise to inform caller without adding to group
raise
except Exception as e:
# User isn't tracked or error occurred - this is not critical
logger.debug(f"[SpectatorHub] Could not get state for {target_id}: {e}")
raise InvokeException("Target user is not currently playing") from e
# Add watcher to our tracked users only after validation
store = self.get_or_create_state(client)
store.watched_user.add(target_id)
# Add to SignalR group for this target user
self.add_to_group(client, self.group_id(target_id))
# Get watcher's username and notify the target user
try:
async with with_db() as session:
username = (await session.exec(select(User.username).where(User.id == user_id))).first()
if not username:
logger.warning(f"[SpectatorHub] Could not find username for user {user_id}")
return
# Notify target user that someone started watching
if (target_client := self.get_client_by_id(str(target_id))) is not None:
# Create watcher info array (matches official format)
watcher_info = [[user_id, username]]
await self.call_noblock(target_client, "UserStartedWatching", watcher_info)
logger.debug(f"[SpectatorHub] Notified {target_id} that {username} started watching")
except Exception as e:
logger.error(f"[SpectatorHub] Error notifying target user {target_id}: {e}")
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
"""
Enhanced EndWatchingUser based on official osu-server-spectator implementation.
Properly cleans up watcher state and notifies target user.
"""
user_id = int(client.connection_id)
logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")
# Remove from SignalR group
self.remove_from_group(client, self.group_id(target_id))
# Remove from our tracked watched users
store = self.get_or_create_state(client)
store.watched_user.discard(target_id)
# Notify target user that watcher stopped watching
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)
logger.debug(f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching")
else:
logger.debug(f"[SpectatorHub] Target user {target_id} not found for end watching notification")

View File

@@ -1,492 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
import datetime
from enum import Enum, IntEnum
import inspect
import json
from types import NoneType, UnionType
from typing import (
Any,
Protocol as TypingProtocol,
Union,
get_args,
get_origin,
)
from app.models.signalr import SignalRMeta, SignalRUnionMessage
from app.utils import camel_to_snake, snake_to_camel, snake_to_pascal
import msgpack_lazer_api as m
from pydantic import BaseModel
SEP = b"\x1e"
class PacketType(IntEnum):
INVOCATION = 1
STREAM_ITEM = 2
COMPLETION = 3
STREAM_INVOCATION = 4
CANCEL_INVOCATION = 5
PING = 6
CLOSE = 7
@dataclass(kw_only=True)
class Packet:
type: PacketType
header: dict[str, Any] | None = None
@dataclass(kw_only=True)
class InvocationPacket(Packet):
type: PacketType = PacketType.INVOCATION
invocation_id: str | None
target: str
arguments: list[Any] | None = None
stream_ids: list[str] | None = None
@dataclass(kw_only=True)
class CompletionPacket(Packet):
type: PacketType = PacketType.COMPLETION
invocation_id: str
result: Any
error: str | None = None
@dataclass(kw_only=True)
class PingPacket(Packet):
type: PacketType = PacketType.PING
@dataclass(kw_only=True)
class ClosePacket(Packet):
type: PacketType = PacketType.CLOSE
error: str | None = None
allow_reconnect: bool = False
PACKETS = {
PacketType.INVOCATION: InvocationPacket,
PacketType.COMPLETION: CompletionPacket,
PacketType.PING: PingPacket,
PacketType.CLOSE: ClosePacket,
}
class Protocol(TypingProtocol):
@staticmethod
def decode(input: bytes) -> list[Packet]: ...
@staticmethod
def encode(packet: Packet) -> bytes: ...
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any: ...
class MsgpackProtocol:
@classmethod
def serialize_msgpack(cls, v: Any) -> Any:
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_to_list(v)
elif issubclass(typ, list):
return [cls.serialize_msgpack(item) for item in v]
elif issubclass(typ, datetime.datetime):
return [v, 0]
elif issubclass(typ, datetime.timedelta):
return int(v.total_seconds() * 10_000_000)
elif isinstance(v, dict):
return {cls.serialize_msgpack(k): cls.serialize_msgpack(value) for k, value in v.items()}
elif issubclass(typ, Enum):
list_ = list(typ)
return list_.index(v) if v in list_ else v.value
return v
@classmethod
def serialize_to_list(cls, value: BaseModel) -> list[Any]:
values = []
for field, info in value.__class__.model_fields.items():
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
if metadata and metadata.member_ignore:
continue
values.append(cls.serialize_msgpack(v=getattr(value, field)))
if issubclass(value.__class__, SignalRUnionMessage):
return [value.__class__.union_type, values]
else:
return values
@staticmethod
def process_object(v: Any, typ: type[BaseModel]) -> Any:
if isinstance(v, list):
d = {}
i = 0
for field, info in typ.model_fields.items():
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
if metadata and metadata.member_ignore:
continue
anno = info.annotation
if anno is None:
d[camel_to_snake(field)] = v[i]
else:
d[field] = MsgpackProtocol.validate_object(v[i], anno)
i += 1
return d
return v
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
@staticmethod
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos
@staticmethod
def decode(input: bytes) -> list[Packet]:
length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length]
unpacked = m.decode(message_data)
packet_type = PacketType(unpacked[0])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return [
InvocationPacket(
header=unpacked[1],
invocation_id=unpacked[2],
target=unpacked[3],
arguments=unpacked[4] if len(unpacked) > 4 else None,
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
]
case PacketType.COMPLETION:
result_kind = unpacked[3]
return [
CompletionPacket(
header=unpacked[1],
invocation_id=unpacked[2],
error=unpacked[4] if result_kind == 1 else None,
result=unpacked[5] if result_kind == 3 else None,
)
]
case PacketType.PING:
return [PingPacket()]
case PacketType.CLOSE:
return [
ClosePacket(
error=unpacked[1],
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
)
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(obj=cls.process_object(v, typ))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return v[0]
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
return datetime.timedelta(seconds=int(v / 10_000_000))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(v, get_args(typ)[1]) for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args):
raise ValueError(f"Cannot validate {v} to {typ}, only SignalRUnionMessage subclasses are supported")
union_type = v[0]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.union_type == union_type:
return cls.validate_object(v[1], arg)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
if isinstance(packet, InvocationPacket):
payload.extend(
[
packet.invocation_id,
packet.target,
]
)
if packet.arguments is not None:
payload.append([MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments])
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
result_kind = 2
if packet.error:
result_kind = 1
elif packet.result is not None:
result_kind = 3
payload.extend(
[
packet.invocation_id,
result_kind,
packet.error or MsgpackProtocol.serialize_msgpack(packet.result) or None,
]
)
elif isinstance(packet, ClosePacket):
payload.extend(
[
packet.error or "",
packet.allow_reconnect,
]
)
elif isinstance(packet, PingPacket):
payload.pop(-1)
data = m.encode(payload)
return MsgpackProtocol._encode_varint(len(data)) + data
class JSONProtocol:
@classmethod
def serialize_to_json(cls, v: Any, dict_key: bool = False, in_union: bool = False):
typ = v.__class__
if issubclass(typ, BaseModel):
return cls.serialize_model(v, in_union)
elif isinstance(v, dict):
return {cls.serialize_to_json(k, True): cls.serialize_to_json(value) for k, value in v.items()}
elif isinstance(v, list):
return [cls.serialize_to_json(item) for item in v]
elif isinstance(v, datetime.datetime):
return v.isoformat()
elif isinstance(v, datetime.timedelta):
# d.hh:mm:ss
total_seconds = int(v.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{hours:02}:{minutes:02}:{seconds:02}"
elif isinstance(v, Enum) and dict_key:
return v.value
elif isinstance(v, Enum):
list_ = list(typ)
return list_.index(v)
return v
@classmethod
def serialize_model(cls, v: BaseModel, in_union: bool = False) -> dict[str, Any]:
d = {}
is_union = issubclass(v.__class__, SignalRUnionMessage)
for field, info in v.__class__.model_fields.items():
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
if metadata and metadata.json_ignore:
continue
name = (
snake_to_camel(
field,
metadata.use_abbr if metadata else True,
)
if not is_union
else snake_to_pascal(
field,
metadata.use_abbr if metadata else True,
)
)
d[name] = cls.serialize_to_json(getattr(v, field), in_union=is_union)
if is_union and not in_union:
return {
"$dtype": v.__class__.__name__,
"$value": d,
}
return d
@staticmethod
def process_object(v: Any, typ: type[BaseModel], from_union: bool = False) -> dict[str, Any]:
d = {}
for field, info in typ.model_fields.items():
metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None)
if metadata and metadata.json_ignore:
continue
name = (
snake_to_camel(field, metadata.use_abbr if metadata else True)
if not from_union
else snake_to_pascal(field, metadata.use_abbr if metadata else True)
)
value = v.get(name)
anno = typ.model_fields[field].annotation
if anno is None:
d[field] = value
continue
d[field] = JSONProtocol.validate_object(value, anno)
return d
@staticmethod
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
packets = []
if len(packets_raw) > 1:
for packet_raw in packets_raw:
packets.extend(JSONProtocol.decode(packet_raw))
return packets
else:
data = json.loads(packets_raw[0])
packet_type = PacketType(data["type"])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return [
InvocationPacket(
header=data.get("header"),
invocation_id=data.get("invocationId"),
target=data["target"],
arguments=data.get("arguments"),
stream_ids=data.get("streamIds"),
)
]
case PacketType.COMPLETION:
return [
CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
]
case PacketType.PING:
return [PingPacket()]
case PacketType.CLOSE:
return [
ClosePacket(
error=data.get("error"),
allow_reconnect=data.get("allowReconnect", False),
)
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@classmethod
def validate_object(cls, v: Any, typ: type, from_union: bool = False) -> Any:
if issubclass(typ, BaseModel):
return typ.model_validate(JSONProtocol.process_object(v, typ, from_union))
elif inspect.isclass(typ) and issubclass(typ, datetime.datetime):
return datetime.datetime.fromisoformat(v)
elif inspect.isclass(typ) and issubclass(typ, datetime.timedelta):
# d.hh:mm:ss
parts = v.split(":")
if len(parts) == 3:
return datetime.timedelta(hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]))
elif len(parts) == 2:
return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1]))
elif len(parts) == 1:
return datetime.timedelta(seconds=int(parts[0]))
elif get_origin(typ) is list:
return [cls.validate_object(item, get_args(typ)[0]) for item in v]
elif inspect.isclass(typ) and issubclass(typ, Enum):
list_ = list(typ)
return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v)
elif get_origin(typ) is dict:
return {
cls.validate_object(k, get_args(typ)[0]): cls.validate_object(v, get_args(typ)[1]) for k, v in v.items()
}
elif (origin := get_origin(typ)) is Union or origin is UnionType:
args = get_args(typ)
if len(args) == 2 and NoneType in args:
non_none_args = [arg for arg in args if arg is not NoneType]
if len(non_none_args) == 1:
if v is None:
return None
return cls.validate_object(v, non_none_args[0])
# suppose use `MessagePack-CSharp Union | None`
# except `X (Other Type) | None`
if NoneType in args and v is None:
return None
if not all(issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args):
raise ValueError(f"Cannot validate {v} to {typ}, only SignalRUnionMessage subclasses are supported")
# https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs
union_type = v["$dtype"]
for arg in args:
assert issubclass(arg, SignalRUnionMessage)
if arg.__name__ == union_type:
return cls.validate_object(v["$value"], arg, True)
return v
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
"type": packet.type.value,
}
if packet.header:
payload["header"] = packet.header
if isinstance(packet, InvocationPacket):
payload.update(
{
"target": packet.target,
}
)
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = [JSONProtocol.serialize_to_json(arg) for arg in packet.arguments]
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
payload.update(
{
"invocationId": packet.invocation_id,
}
)
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = JSONProtocol.serialize_to_json(packet.result)
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):
payload.update(
{
"allowReconnect": packet.allow_reconnect,
}
)
if packet.error is not None:
payload["error"] = packet.error
return json.dumps(payload).encode("utf-8") + SEP
PROTOCOLS: dict[str, Protocol] = {
"json": JSONProtocol,
"messagepack": MsgpackProtocol,
}

View File

@@ -1,119 +0,0 @@
from __future__ import annotations
import asyncio
import json
import time
from typing import Literal
import uuid
from app.database import User as DBUser
from app.dependencies.database import DBFactory, get_db_factory
from app.dependencies.user import get_current_user, get_current_user_and_token
from app.log import logger
from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs
from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
from fastapi.security import SecurityScopes
router = APIRouter(prefix="/signalr", include_in_schema=False)
logger.warning(
"The Python version of SignalR server is deprecated. "
"Maybe it will be removed or be fixed to continuously use in the future"
)
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
async def negotiate(
hub: Literal["spectator", "multiplayer", "metadata"],
negotiate_version: int = Query(1, alias="negotiateVersion"),
user: DBUser = Depends(get_current_user),
):
connectionId = str(user.id)
connectionToken = f"{connectionId}:{uuid.uuid4()}"
Hubs[hub].add_waited_client(
connection_token=connectionToken,
timestamp=int(time.time()),
)
return NegotiateResponse(
connectionId=connectionId,
connectionToken=connectionToken,
negotiateVersion=negotiate_version,
availableTransports=[Transport(transport="WebSockets")],
)
@router.websocket("/{hub}")
async def connect(
hub: Literal["spectator", "multiplayer", "metadata"],
websocket: WebSocket,
id: str,
authorization: str = Header(...),
factory: DBFactory = Depends(get_db_factory),
):
token = authorization[7:]
user_id = id.split(":")[0]
hub_ = Hubs[hub]
if id not in hub_:
await websocket.close(code=1008)
return
try:
async for session in factory():
if (
user_and_token := await get_current_user_and_token(
session, SecurityScopes(scopes=["*"]), token_pw=token
)
) is None or str(user_and_token[0].id) != user_id:
await websocket.close(code=1008)
return
except HTTPException:
await websocket.close(code=1008)
return
await websocket.accept()
# handshake
handshake = await websocket.receive()
message = handshake.get("bytes") or handshake.get("text")
if not message:
await websocket.close(code=1008)
return
handshake_payload = json.loads(message[:-1])
error = ""
protocol = handshake_payload.get("protocol", "json")
client = None
try:
client = await hub_.add_client(
connection_id=user_id,
connection_token=id,
connection=websocket,
protocol=PROTOCOLS[protocol],
)
except KeyError:
error = f"Protocol '{protocol}' is not supported."
except TimeoutError:
error = f"Connection {id} has waited too long."
except ValueError as e:
error = str(e)
payload = {"error": error} if error else {}
# finish handshake
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
if error or not client:
await websocket.close(code=1008)
return
connected_clients = hub_.get_before_clients(user_id, id)
for connected_client in connected_clients:
await hub_.kick_client(connected_client)
await hub_.clean_state(client, False)
task = asyncio.create_task(hub_.on_connect(client))
hub_.tasks.add(task)
task.add_done_callback(hub_.tasks.discard)
await hub_._listen_client(client)
try:
await websocket.close()
except Exception:
...

View File

@@ -1,37 +0,0 @@
from __future__ import annotations
import asyncio
import sys
from typing import Any
class ResultStore:
def __init__(self) -> None:
self._seq: int = 1
self._futures: dict[str, asyncio.Future] = {}
@property
def current_invocation_id(self) -> int:
return self._seq
def get_invocation_id(self) -> str:
s = self._seq
self._seq = (self._seq + 1) % sys.maxsize
return str(s)
def add_result(self, invocation_id: str, result: Any, error: str | None = None) -> None:
if isinstance(invocation_id, str) and invocation_id.isdecimal():
if future := self._futures.get(invocation_id):
future.set_result((result, error))
async def fetch(
self,
invocation_id: str,
timeout: float | None, # noqa: ASYNC109
) -> tuple[Any, str | None]:
future = asyncio.get_event_loop().create_future()
self._futures[invocation_id] = future
try:
return await asyncio.wait_for(future, timeout)
finally:
del self._futures[invocation_id]

View File

@@ -1,42 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
import inspect
import sys
from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L61-L75
if sys.version_info < (3, 12, 4):
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
try:
annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception:
return inspect.Parameter.empty
return annotation
def get_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_annotation(param, globalns),
)
for param in signature.parameters.values()
]
return inspect.Signature(typed_params)

View File

@@ -14,7 +14,6 @@ from app.database.user import User
from app.dependencies.database import get_redis, with_db
from app.dependencies.scheduler import get_scheduler
from app.log import logger
from app.models.metadata_hub import DailyChallengeInfo
from app.models.mods import APIMod, get_available_mods
from app.models.room import RoomCategory
from app.service.room import create_playlist_room
@@ -54,8 +53,6 @@ async def create_daily_challenge_room(
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="daily_challenge")
async def daily_challenge_job():
from app.signalr.hub import MetadataHubs
now = utcnow()
redis = get_redis()
key = f"daily_challenge:{now.date()}"
@@ -108,7 +105,6 @@ async def daily_challenge_job():
allowed_mods=allowed_mods_list,
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
)
await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
logger.success(f"Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
return
except (ValueError, json.JSONDecodeError) as e: