from __future__ import annotations from abc import ABC, abstractmethod import asyncio from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from enum import IntEnum from typing import ( TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypedDict, cast, override, ) from app.database.beatmap import Beatmap from app.dependencies.database import engine from app.dependencies.fetcher import get_fetcher from app.exception import InvokeException from .mods import APIMod from .room import ( DownloadState, MatchType, MultiplayerRoomState, MultiplayerUserState, QueueMode, RoomCategory, RoomStatus, ) from .signalr import ( SignalRMeta, SignalRUnionMessage, UserState, ) from pydantic import BaseModel, Field from sqlalchemy import update from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.signalr.hub import MultiplayerHub HOST_LIMIT = 50 PER_USER_LIMIT = 3 class MultiplayerClientState(UserState): room_id: int = 0 class MultiplayerRoomSettings(BaseModel): name: str = "Unnamed Room" playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] password: str = "" match_type: MatchType = MatchType.HEAD_TO_HEAD queue_mode: QueueMode = QueueMode.HOST_ONLY auto_start_duration: timedelta = timedelta(seconds=0) auto_skip: bool = False @property def auto_start_enabled(self) -> bool: return self.auto_start_duration != timedelta(seconds=0) class BeatmapAvailability(BaseModel): state: DownloadState = DownloadState.UNKNOWN download_progress: float | None = None class _MatchUserState(SignalRUnionMessage): ... class TeamVersusUserState(_MatchUserState): team_id: int union_type: ClassVar[Literal[0]] = 0 MatchUserState = TeamVersusUserState class _MatchRoomState(SignalRUnionMessage): ... class MultiplayerTeam(BaseModel): id: int name: str class TeamVersusRoomState(_MatchRoomState): teams: list[MultiplayerTeam] = Field( default_factory=lambda: [ MultiplayerTeam(id=0, name="Team Red"), MultiplayerTeam(id=1, name="Team Blue"), ] ) union_type: ClassVar[Literal[0]] = 0 MatchRoomState = TeamVersusRoomState class PlaylistItem(BaseModel): id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)] owner_id: int beatmap_id: int beatmap_checksum: str ruleset_id: int required_mods: list[APIMod] = Field(default_factory=list) allowed_mods: list[APIMod] = Field(default_factory=list) expired: bool playlist_order: int played_at: datetime | None = None star_rating: float freestyle: bool def _get_api_mods(self): from app.models.mods import API_MODS, init_mods if not API_MODS: init_mods() return API_MODS def _validate_mod_for_ruleset( self, mod: APIMod, ruleset_key: int, context: str = "mod" ) -> None: from typing import Literal, cast API_MODS = self._get_api_mods() typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) # Check if mod is valid for ruleset if ( typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key] ): raise InvokeException( f"{context} {mod['acronym']} is invalid for this ruleset" ) mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] # Check if mod is unplayable in multiplayer if mod_settings.get("UserPlayable", True) is False: raise InvokeException( f"{context} {mod['acronym']} is not playable by users" ) if mod_settings.get("ValidForMultiplayer", True) is False: raise InvokeException( f"{context} {mod['acronym']} is not valid for multiplayer" ) def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: from typing import Literal, cast API_MODS = self._get_api_mods() typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) for i, mod1 in enumerate(mods): mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"]) if mod1_settings: incompatible = set(mod1_settings.get("IncompatibleMods", [])) for mod2 in mods[i + 1 :]: if mod2["acronym"] in incompatible: raise InvokeException( f"Mods {mod1['acronym']} and " f"{mod2['acronym']} are incompatible" ) def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: from typing import Literal, cast API_MODS = self._get_api_mods() typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} for req_mod in self.required_mods: req_acronym = req_mod["acronym"] req_settings = API_MODS[typed_ruleset_key].get(req_acronym) if req_settings: incompatible = set(req_settings.get("IncompatibleMods", [])) conflicting_allowed = allowed_acronyms & incompatible if conflicting_allowed: conflict_list = ", ".join(conflicting_allowed) raise InvokeException( f"Required mod {req_acronym} conflicts with " f"allowed mods: {conflict_list}" ) def validate_playlist_item_mods(self) -> None: ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) # Validate required mods for mod in self.required_mods: self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod") # Validate allowed mods for mod in self.allowed_mods: self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod") # Check internal compatibility of required mods self._check_mod_compatibility(self.required_mods, ruleset_key) # Check compatibility between required and allowed mods self._check_required_allowed_compatibility(ruleset_key) def validate_user_mods( self, user: "MultiplayerRoomUser", proposed_mods: list[APIMod], ) -> tuple[bool, list[APIMod]]: """ Validates user mods against playlist item rules and returns valid mods. Returns (is_valid, valid_mods). """ from typing import Literal, cast API_MODS = self._get_api_mods() ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) valid_mods = [] all_proposed_valid = True # Check if mods are valid for the ruleset for mod in proposed_mods: if ( ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key] ): all_proposed_valid = False continue valid_mods.append(mod) # Check mod compatibility within user mods incompatible_mods = set() final_valid_mods = [] for mod in valid_mods: if mod["acronym"] in incompatible_mods: all_proposed_valid = False continue setting_mods = API_MODS[ruleset_key].get(mod["acronym"]) if setting_mods: incompatible_mods.update(setting_mods["IncompatibleMods"]) final_valid_mods.append(mod) # If not freestyle, check against allowed mods if not self.freestyle: allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} filtered_valid_mods = [] for mod in final_valid_mods: if mod["acronym"] not in allowed_acronyms: all_proposed_valid = False else: filtered_valid_mods.append(mod) final_valid_mods = filtered_valid_mods # Check compatibility with required mods required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} all_mod_acronyms = { mod["acronym"] for mod in final_valid_mods } | required_mod_acronyms # Check for incompatibility between required and user mods filtered_valid_mods = [] for mod in final_valid_mods: mod_acronym = mod["acronym"] is_compatible = True for other_acronym in all_mod_acronyms: if other_acronym == mod_acronym: continue setting_mods = API_MODS[ruleset_key].get(mod_acronym) if setting_mods and other_acronym in setting_mods["IncompatibleMods"]: is_compatible = False all_proposed_valid = False break if is_compatible: filtered_valid_mods.append(mod) return all_proposed_valid, filtered_valid_mods def clone(self) -> "PlaylistItem": copy = self.model_copy() copy.required_mods = list(self.required_mods) copy.allowed_mods = list(self.allowed_mods) copy.expired = False copy.played_at = None return copy class _MultiplayerCountdown(SignalRUnionMessage): id: int = 0 time_remaining: timedelta is_exclusive: Annotated[ bool, Field(default=True), SignalRMeta(member_ignore=True) ] = True class MatchStartCountdown(_MultiplayerCountdown): union_type: ClassVar[Literal[0]] = 0 class ForceGameplayStartCountdown(_MultiplayerCountdown): union_type: ClassVar[Literal[1]] = 1 class ServerShuttingDownCountdown(_MultiplayerCountdown): union_type: ClassVar[Literal[2]] = 2 MultiplayerCountdown = ( MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown ) class MultiplayerRoomUser(BaseModel): user_id: int state: MultiplayerUserState = MultiplayerUserState.IDLE availability: BeatmapAvailability = BeatmapAvailability( state=DownloadState.UNKNOWN, download_progress=None ) mods: list[APIMod] = Field(default_factory=list) match_state: MatchUserState | None = None ruleset_id: int | None = None # freestyle beatmap_id: int | None = None # freestyle class MultiplayerRoom(BaseModel): room_id: int state: MultiplayerRoomState settings: MultiplayerRoomSettings users: list[MultiplayerRoomUser] = Field(default_factory=list) host: MultiplayerRoomUser | None = None match_state: MatchRoomState | None = None playlist: list[PlaylistItem] = Field(default_factory=list) active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list) channel_id: int @classmethod def from_db(cls, room) -> "MultiplayerRoom": """ 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) """ # 用户列表 users = [MultiplayerRoomUser(user_id=room.host_id)] host_user = MultiplayerRoomUser(user_id=room.host_id) # playlist 转换 playlist = [] if hasattr(room, "playlist"): for item in room.playlist: playlist.append( PlaylistItem( id=item.id, owner_id=item.owner_id, beatmap_id=item.beatmap_id, beatmap_checksum=item.beatmap.checksum if item.beatmap else "", ruleset_id=item.ruleset_id, required_mods=item.required_mods, allowed_mods=item.allowed_mods, expired=item.expired, playlist_order=item.playlist_order, played_at=item.played_at, star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0, freestyle=item.freestyle, ) ) return cls( room_id=room.id, state=getattr(room, "state", MultiplayerRoomState.OPEN), settings=MultiplayerRoomSettings( name=room.name, playlist_item_id=playlist[0].id if playlist else 0, password=getattr(room, "password", ""), match_type=room.type, queue_mode=room.queue_mode, auto_start_duration=timedelta(seconds=room.auto_start_duration), auto_skip=room.auto_skip, ), users=users, host=host_user, match_state=None, playlist=playlist, active_countdowns=[], channel_id=getattr(room, "channel_id", 0), ) class MultiplayerQueue: def __init__(self, room: "ServerMultiplayerRoom"): self.server_room = room self.current_index = 0 @property def hub(self) -> "MultiplayerHub": return self.server_room.hub @property def upcoming_items(self): return sorted( (item for item in self.room.playlist if not item.expired), key=lambda i: i.playlist_order, ) @property def room(self): return self.server_room.room async def update_order(self): from app.database import Playlist match self.room.settings.queue_mode: case QueueMode.ALL_PLAYERS_ROUND_ROBIN: ordered_active_items = [] is_first_set = True first_set_order_by_user_id = {} active_items = [item for item in self.room.playlist if not item.expired] active_items.sort(key=lambda x: x.id) user_item_groups = {} for item in active_items: if item.owner_id not in user_item_groups: user_item_groups[item.owner_id] = [] user_item_groups[item.owner_id].append(item) max_items = max( (len(items) for items in user_item_groups.values()), default=0 ) for i in range(max_items): current_set = [] for user_id, items in user_item_groups.items(): if i < len(items): current_set.append(items[i]) if is_first_set: current_set.sort( key=lambda item: (item.playlist_order, item.id) ) ordered_active_items.extend(current_set) first_set_order_by_user_id = { item.owner_id: idx for idx, item in enumerate(ordered_active_items) } else: current_set.sort( key=lambda item: first_set_order_by_user_id.get( item.owner_id, 0 ) ) ordered_active_items.extend(current_set) is_first_set = False for idx, item in enumerate(ordered_active_items): item.playlist_order = idx case _: ordered_active_items = sorted( (item for item in self.room.playlist if not item.expired), key=lambda x: x.id, ) async with AsyncSession(engine) as session: for idx, item in enumerate(ordered_active_items): if item.playlist_order == idx: continue item.playlist_order = idx await Playlist.update(item, self.room.room_id, session) await self.hub.playlist_changed( self.server_room, item, beatmap_changed=False ) async def update_current_item(self): upcoming_items = self.upcoming_items next_item = ( upcoming_items[0] if upcoming_items else max( self.room.playlist, key=lambda i: i.played_at or datetime.min, ) ) self.current_index = self.room.playlist.index(next_item) last_id = self.room.settings.playlist_item_id self.room.settings.playlist_item_id = next_item.id if last_id != next_item.id: await self.hub.setting_changed(self.server_room, True) async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser): from app.database import Playlist is_host = self.room.host and self.room.host.user_id == user.user_id if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host: raise InvokeException("You are not the host") limit = HOST_LIMIT if is_host else PER_USER_LIMIT if ( len([True for u in self.room.playlist if u.owner_id == user.user_id]) >= limit ): raise InvokeException(f"You can only have {limit} items in the queue") if item.freestyle and len(item.allowed_mods) > 0: raise InvokeException("Freestyle items cannot have allowed mods") async with AsyncSession(engine) as session: fetcher = await get_fetcher() async with session: beatmap = await Beatmap.get_or_fetch( session, fetcher, bid=item.beatmap_id ) if beatmap is None: raise InvokeException("Beatmap not found") if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") item.validate_playlist_item_mods() item.owner_id = user.user_id item.star_rating = float( beatmap.difficulty_rating ) # FIXME: beatmap use decimal await Playlist.add_to_db(item, self.room.room_id, session) self.room.playlist.append(item) await self.hub.playlist_added(self.server_room, item) await self.update_order() await self.update_current_item() async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser): from app.database import Playlist if item.freestyle and len(item.allowed_mods) > 0: raise InvokeException("Freestyle items cannot have allowed mods") async with AsyncSession(engine) as session: fetcher = await get_fetcher() async with session: beatmap = await Beatmap.get_or_fetch( session, fetcher, bid=item.beatmap_id ) if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") existing_item = next( (i for i in self.room.playlist if i.id == item.id), None ) if existing_item is None: raise InvokeException( "Attempted to change an item that doesn't exist" ) if existing_item.owner_id != user.user_id and self.room.host != user: raise InvokeException( "Attempted to change an item which is not owned by the user" ) if existing_item.expired: raise InvokeException( "Attempted to change an item which has already been played" ) item.validate_playlist_item_mods() item.owner_id = user.user_id item.star_rating = float(beatmap.difficulty_rating) item.playlist_order = existing_item.playlist_order await Playlist.update(item, self.room.room_id, session) # Update item in playlist for idx, playlist_item in enumerate(self.room.playlist): if playlist_item.id == item.id: self.room.playlist[idx] = item break await self.hub.playlist_changed( self.server_room, item, beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum, ) async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): from app.database import Playlist item = next( (i for i in self.room.playlist if i.id == playlist_item_id), None, ) if item is None: raise InvokeException("Item does not exist in the room") # Check if it's the only item and current item if item == self.current_item: upcoming_items = [i for i in self.room.playlist if not i.expired] if len(upcoming_items) == 1: raise InvokeException("The only item in the room cannot be removed") if item.owner_id != user.user_id and self.room.host != user: raise InvokeException( "Attempted to remove an item which is not owned by the user" ) if item.expired: raise InvokeException( "Attempted to remove an item which has already been played" ) async with AsyncSession(engine) as session: await Playlist.delete_item(item.id, self.room.room_id, session) self.room.playlist.remove(item) self.current_index = self.room.playlist.index(self.upcoming_items[0]) await self.update_order() await self.update_current_item() await self.hub.playlist_removed(self.server_room, item.id) async def finish_current_item(self): from app.database import Playlist async with AsyncSession(engine) as session: played_at = datetime.now(UTC) await session.execute( update(Playlist) .where( col(Playlist.id) == self.current_item.id, col(Playlist.room_id) == self.room.room_id, ) .values(expired=True, played_at=played_at) ) self.room.playlist[self.current_index].expired = True self.room.playlist[self.current_index].played_at = played_at await self.hub.playlist_changed(self.server_room, self.current_item, True) await self.update_order() if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( playitem.expired for playitem in self.room.playlist ): assert self.room.host await self.add_item(self.current_item.clone(), self.room.host) async def update_queue_mode(self): if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all( playitem.expired for playitem in self.room.playlist ): assert self.room.host await self.add_item(self.current_item.clone(), self.room.host) await self.update_order() await self.update_current_item() @property def current_item(self): return self.room.playlist[self.current_index] @dataclass class CountdownInfo: countdown: MultiplayerCountdown duration: timedelta task: asyncio.Task | None = None def __init__(self, countdown: MultiplayerCountdown): self.countdown = countdown self.duration = ( countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0) ) class _MatchRequest(SignalRUnionMessage): ... class ChangeTeamRequest(_MatchRequest): union_type: ClassVar[Literal[0]] = 0 team_id: int class StartMatchCountdownRequest(_MatchRequest): union_type: ClassVar[Literal[1]] = 1 duration: timedelta class StopCountdownRequest(_MatchRequest): union_type: ClassVar[Literal[2]] = 2 id: int MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest class MatchTypeHandler(ABC): def __init__(self, room: "ServerMultiplayerRoom"): self.room = room self.hub = room.hub @abstractmethod async def handle_join(self, user: MultiplayerRoomUser): ... @abstractmethod async def handle_request( self, user: MultiplayerRoomUser, request: MatchRequest ): ... @abstractmethod async def handle_leave(self, user: MultiplayerRoomUser): ... @abstractmethod def get_details(self) -> MatchStartedEventDetail: ... class HeadToHeadHandler(MatchTypeHandler): @override async def handle_join(self, user: MultiplayerRoomUser): if user.match_state is not None: user.match_state = None await self.hub.change_user_match_state(self.room, user) @override async def handle_request( self, user: MultiplayerRoomUser, request: MatchRequest ): ... @override async def handle_leave(self, user: MultiplayerRoomUser): ... @override def get_details(self) -> MatchStartedEventDetail: detail = MatchStartedEventDetail(room_type="head_to_head", team=None) return detail class TeamVersusHandler(MatchTypeHandler): @override def __init__(self, room: "ServerMultiplayerRoom"): super().__init__(room) self.state = TeamVersusRoomState() room.room.match_state = self.state task = asyncio.create_task(self.hub.change_room_match_state(self.room)) self.hub.tasks.add(task) task.add_done_callback(self.hub.tasks.discard) def _get_best_available_team(self) -> int: for team in self.state.teams: if all( ( user.match_state is None or not isinstance(user.match_state, TeamVersusUserState) or user.match_state.team_id != team.id ) for user in self.room.room.users ): return team.id from collections import defaultdict team_counts = defaultdict(int) for user in self.room.room.users: if user.match_state is not None and isinstance( user.match_state, TeamVersusUserState ): team_counts[user.match_state.team_id] += 1 if team_counts: min_count = min(team_counts.values()) for team_id, count in team_counts.items(): if count == min_count: return team_id return self.state.teams[0].id if self.state.teams else 0 @override async def handle_join(self, user: MultiplayerRoomUser): best_team_id = self._get_best_available_team() user.match_state = TeamVersusUserState(team_id=best_team_id) await self.hub.change_user_match_state(self.room, user) @override async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): if not isinstance(request, ChangeTeamRequest): return if request.team_id not in [team.id for team in self.state.teams]: raise InvokeException("Invalid team ID") user.match_state = TeamVersusUserState(team_id=request.team_id) await self.hub.change_user_match_state(self.room, user) @override async def handle_leave(self, user: MultiplayerRoomUser): ... @override def get_details(self) -> MatchStartedEventDetail: teams: dict[int, Literal["blue", "red"]] = {} for user in self.room.room.users: if user.match_state is not None and isinstance( user.match_state, TeamVersusUserState ): teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red" detail = MatchStartedEventDetail(room_type="team_versus", team=teams) return detail MATCH_TYPE_HANDLERS = { MatchType.HEAD_TO_HEAD: HeadToHeadHandler, MatchType.TEAM_VERSUS: TeamVersusHandler, } @dataclass class ServerMultiplayerRoom: room: MultiplayerRoom category: RoomCategory status: RoomStatus start_at: datetime hub: "MultiplayerHub" match_type_handler: MatchTypeHandler queue: MultiplayerQueue _next_countdown_id: int _countdown_id_lock: asyncio.Lock _tracked_countdown: dict[int, CountdownInfo] def __init__( self, room: MultiplayerRoom, category: RoomCategory, start_at: datetime, hub: "MultiplayerHub", ): self.room = room self.category = category self.status = RoomStatus.IDLE self.start_at = start_at self.hub = hub self.queue = MultiplayerQueue(self) self._next_countdown_id = 0 self._countdown_id_lock = asyncio.Lock() self._tracked_countdown = {} async def set_handler(self): self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type]( self ) for i in self.room.users: await self.match_type_handler.handle_join(i) async def get_next_countdown_id(self) -> int: async with self._countdown_id_lock: self._next_countdown_id += 1 return self._next_countdown_id async def start_countdown( self, countdown: MultiplayerCountdown, on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None, ): async def _countdown_task(self: "ServerMultiplayerRoom"): await asyncio.sleep(info.duration.total_seconds()) if on_complete is not None: await on_complete(self) await self.stop_countdown(countdown) if countdown.is_exclusive: await self.stop_all_countdowns() countdown.id = await self.get_next_countdown_id() info = CountdownInfo(countdown) self.room.active_countdowns.append(info.countdown) self._tracked_countdown[countdown.id] = info await self.hub.send_match_event( self, CountdownStartedEvent(countdown=info.countdown) ) info.task = asyncio.create_task(_countdown_task(self)) async def stop_countdown(self, countdown: MultiplayerCountdown): info = self._tracked_countdown.get(countdown.id) if info is None: return del self._tracked_countdown[countdown.id] self.room.active_countdowns.remove(countdown) await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id)) if info.task is not None and not info.task.done(): info.task.cancel() async def stop_all_countdowns(self): for countdown in list(self._tracked_countdown.values()): await self.stop_countdown(countdown.countdown) self._tracked_countdown.clear() self.room.active_countdowns.clear() class _MatchServerEvent(SignalRUnionMessage): ... class CountdownStartedEvent(_MatchServerEvent): countdown: MultiplayerCountdown union_type: ClassVar[Literal[0]] = 0 class CountdownStoppedEvent(_MatchServerEvent): id: int union_type: ClassVar[Literal[1]] = 1 MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent class GameplayAbortReason(IntEnum): LOAD_TOOK_TOO_LONG = 0 HOST_ABORTED = 1 class MatchStartedEventDetail(TypedDict): room_type: Literal["playlists", "head_to_head", "team_versus"] team: dict[int, Literal["blue", "red"]] | None