From 9da9f27febcccabbded8dd736c8376976fe66c4e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 4 Aug 2025 02:20:14 +0000 Subject: [PATCH] feat(multiplayer): complete validation --- app/models/multiplayer_hub.py | 103 ++++++++++++++++++++++++++++++--- app/signalr/hub/multiplayer.py | 38 ++++++++++-- 2 files changed, 129 insertions(+), 12 deletions(-) diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 97148db..e2f4edf 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import UTC, datetime, timedelta from enum import IntEnum -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, override +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, cast, override from app.database.beatmap import Beatmap from app.dependencies.database import engine @@ -107,6 +107,97 @@ class PlaylistItem(BaseModel): star_rating: float freestyle: bool + def _get_api_mods(self): + from app.models.mods import API_MODS, init_mods + + if not API_MODS: + init_mods() + return API_MODS + + def _validate_mod_for_ruleset( + self, mod: APIMod, ruleset_key: int, context: str = "mod" + ) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + # Check if mod is valid for ruleset + if ( + typed_ruleset_key not in API_MODS + or mod["acronym"] not in API_MODS[typed_ruleset_key] + ): + raise InvokeException( + f"{context} {mod['acronym']} is invalid for this ruleset" + ) + + mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] + + # Check if mod is unplayable in multiplayer + if mod_settings.get("UserPlayable", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not playable by users" + ) + + if mod_settings.get("ValidForMultiplayer", True) is False: + raise InvokeException( + f"{context} {mod['acronym']} is not valid for multiplayer" + ) + + def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + + for i, mod1 in enumerate(mods): + mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"]) + if mod1_settings: + incompatible = set(mod1_settings.get("IncompatibleMods", [])) + for mod2 in mods[i + 1 :]: + if mod2["acronym"] in incompatible: + raise InvokeException( + f"Mods {mod1['acronym']} and " + f"{mod2['acronym']} are incompatible" + ) + + def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: + from typing import Literal, cast + + API_MODS = self._get_api_mods() + typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) + allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} + + for req_mod in self.required_mods: + req_acronym = req_mod["acronym"] + req_settings = API_MODS[typed_ruleset_key].get(req_acronym) + if req_settings: + incompatible = set(req_settings.get("IncompatibleMods", [])) + conflicting_allowed = allowed_acronyms & incompatible + if conflicting_allowed: + conflict_list = ", ".join(conflicting_allowed) + raise InvokeException( + f"Required mod {req_acronym} conflicts with " + f"allowed mods: {conflict_list}" + ) + + def validate_playlist_item_mods(self) -> None: + ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) + + # Validate required mods + for mod in self.required_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod") + + # Validate allowed mods + for mod in self.allowed_mods: + self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod") + + # Check internal compatibility of required mods + self._check_mod_compatibility(self.required_mods, ruleset_key) + + # Check compatibility between required and allowed mods + self._check_required_allowed_compatibility(ruleset_key) + def validate_user_mods( self, user: "MultiplayerRoomUser", @@ -118,10 +209,7 @@ class PlaylistItem(BaseModel): """ from typing import Literal, cast - from app.models.mods import API_MODS, init_mods - - if not API_MODS: - init_mods() + API_MODS = self._get_api_mods() ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) @@ -367,7 +455,8 @@ class MultiplayerQueue: raise InvokeException("Beatmap not found") if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") - # TODO: mods validation + + item.validate_playlist_item_mods() item.owner_id = user.user_id item.star_rating = float( beatmap.difficulty_rating @@ -410,7 +499,7 @@ class MultiplayerQueue: "Attempted to change an item which has already been played" ) - # TODO: mods validation + item.validate_playlist_item_mods() item.owner_id = user.user_id item.star_rating = float(beatmap.difficulty_rating) item.playlist_order = existing_item.playlist_order diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index eb602fe..3be8024 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -64,6 +64,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]): connection_token=client.connection_token, ) + @override + async def _clean_state(self, state: MultiplayerClientState): + user_id = int(state.connection_id) + if state.room_id != 0 and state.room_id in self.rooms: + server_room = self.rooms[state.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == user_id), None) + if user is not None: + await self.make_user_leave( + self.get_client_by_id(str(user_id)), server_room, user + ) + async def CreateRoom(self, client: Client, room: MultiplayerRoom): logger.info(f"[MultiplayerHub] {client.user_id} creating room") store = self.get_or_create_state(client) @@ -554,8 +566,17 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("You are not in this room") if room.host is None or room.host.user_id != client.user_id: raise InvokeException("You are not the host of this room") - if any(u.state != MultiplayerUserState.READY for u in room.users): - raise InvokeException("Not all users are ready") + + # Check host state - host must be ready or spectating + if room.host.state not in ( + MultiplayerUserState.SPECTATING, + MultiplayerUserState.READY, + ): + raise InvokeException("Can't start match when the host is not ready.") + + # Check if any users are ready + if all(u.state != MultiplayerUserState.READY for u in room.users): + raise InvokeException("Can't start match when no users are ready.") await self.start_match(server_room) @@ -646,7 +667,11 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if len(room.room.users) == 0: await self.end_room(room) await self.update_room_state(room) - if room.room.host and room.room.host.user_id == user.user_id: + if ( + len(room.room.users) != 0 + and room.room.host + and room.room.host.user_id == user.user_id + ): next_host = room.room.users[0] await self.set_host(room, next_host) @@ -710,6 +735,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if room.host is None or room.host.user_id != client.user_id: raise InvokeException("You are not the host of this room") + if user_id == client.user_id: + raise InvokeException("Can't kick self") + user = next((u for u in room.users if u.user_id == user_id), None) if user is None: raise InvokeException("User not found in this room") @@ -780,9 +808,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if ( room.state != MultiplayerRoomState.PLAYING - or room.state == MultiplayerRoomState.WAITING_FOR_LOAD + and room.state != MultiplayerRoomState.WAITING_FOR_LOAD ): - raise InvokeException("Room is not in a playable state") + raise InvokeException("Cannot abort a match that hasn't started.") await asyncio.gather( *[