diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 39ced12..9bccb71 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -103,6 +103,84 @@ class PlaylistItem(MessagePackArrayModel): star: float freestyle: bool + def validate_user_mods( + self, + user: "MultiplayerRoomUser", + proposed_mods: list[APIMod], + ) -> tuple[bool, list[APIMod]]: + """ + Validates user mods against playlist item rules and returns valid mods. + Returns (is_valid, valid_mods). + """ + from typing import Literal, cast + + from app.models.mods import API_MODS, init_mods + + if not API_MODS: + init_mods() + + ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id + ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) + + valid_mods = [] + all_proposed_valid = True + + # Check if mods are valid for the ruleset + for mod in proposed_mods: + if ruleset_key not in API_MODS or mod.acronym not in API_MODS[ruleset_key]: + all_proposed_valid = False + continue + valid_mods.append(mod) + + # Check mod compatibility within user mods + incompatible_mods = set() + final_valid_mods = [] + for mod in valid_mods: + if mod.acronym in incompatible_mods: + all_proposed_valid = False + continue + setting_mods = API_MODS[ruleset_key].get(mod.acronym) + if setting_mods: + incompatible_mods.update(setting_mods["IncompatibleMods"]) + final_valid_mods.append(mod) + + # If not freestyle, check against allowed mods + if not self.freestyle: + allowed_acronyms = {mod.acronym for mod in self.allowed_mods} + filtered_valid_mods = [] + for mod in final_valid_mods: + if mod.acronym not in allowed_acronyms: + all_proposed_valid = False + else: + filtered_valid_mods.append(mod) + final_valid_mods = filtered_valid_mods + + # Check compatibility with required mods + required_mod_acronyms = {mod.acronym for mod in self.required_mods} + all_mod_acronyms = { + mod.acronym for mod in final_valid_mods + } | required_mod_acronyms + + # Check for incompatibility between required and user mods + filtered_valid_mods = [] + for mod in final_valid_mods: + mod_acronym = mod.acronym + is_compatible = True + + for other_acronym in all_mod_acronyms: + if other_acronym == mod_acronym: + continue + setting_mods = API_MODS[ruleset_key].get(mod_acronym) + if setting_mods and other_acronym in setting_mods["IncompatibleMods"]: + is_compatible = False + all_proposed_valid = False + break + + if is_compatible: + filtered_valid_mods.append(mod) + + return all_proposed_valid, filtered_valid_mods + class _MultiplayerCountdown(MessagePackArrayModel): id: int @@ -405,7 +483,6 @@ class MultiplayerQueue: current_id = self.room.settings.playlist_item_id return next( (item for item in self.room.playlist if item.id == current_id), - None, ) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 477396b..bd34be0 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import override from app.database import Room +from app.database.beatmap import Beatmap from app.database.playlists import Playlist from app.dependencies.database import engine from app.exception import InvokeException @@ -17,10 +18,13 @@ from app.models.multiplayer_hub import ( ServerMultiplayerRoom, ) from app.models.room import RoomCategory, RoomStatus +from app.models.score import GameMode from app.models.signalr import serialize_to_list from .hub import Client, Hub +from msgpack_lazer_api import APIMod +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -219,3 +223,175 @@ class MultiplayerHub(Hub[MultiplayerClientState]): "PlaylistItemChanged", serialize_to_list(item), ) + + async def ChangeUserStyle( + self, client: Client, beatmap_id: int | None, ruleset_id: int | None + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_style( + beatmap_id, + ruleset_id, + server_room, + user, + ) + + async def validate_styles(self, room: ServerMultiplayerRoom): + assert room.queue + if not room.queue.current_item.freestyle: + for user in room.room.users: + await self.change_user_style( + None, + None, + room, + user, + ) + async with AsyncSession(engine) as session: + beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + beatmap_ids = ( + await session.exec( + select(Beatmap.id, Beatmap.mode).where( + Beatmap.beatmapset_id == beatmap.beatmapset_id, + ) + ) + ).all() + for user in room.room.users: + beatmap_id = user.beatmap_id + ruleset_id = user.ruleset_id + user_beatmap = next( + (b for b in beatmap_ids if b[0] == beatmap_id), + None, + ) + if beatmap_id is not None and user_beatmap is None: + beatmap_id = None + beatmap_ruleset = user_beatmap[1] if user_beatmap else beatmap.mode + if ( + ruleset_id is not None + and beatmap_ruleset != GameMode.OSU + and ruleset_id != beatmap_ruleset + ): + ruleset_id = None + await self.change_user_style( + beatmap_id, + ruleset_id, + room, + user, + ) + + for user in room.room.users: + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, user.mods + ) + if not is_valid: + await self.change_user_mods(valid_mods, room, user) + + async def change_user_style( + self, + beatmap_id: int | None, + ruleset_id: int | None, + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + if user.beatmap_id == beatmap_id and user.ruleset_id == ruleset_id: + return + + if beatmap_id is not None or ruleset_id is not None: + assert room.queue + if not room.queue.current_item.freestyle: + raise InvokeException("Current item does not allow free user styles.") + + async with AsyncSession(engine) as session: + item_beatmap = await session.get( + Beatmap, room.queue.current_item.beatmap_id + ) + if item_beatmap is None: + raise InvokeException("Item beatmap not found") + + user_beatmap = ( + item_beatmap + if beatmap_id is None + else await session.get(Beatmap, beatmap_id) + ) + + if user_beatmap is None: + raise InvokeException("Invalid beatmap selected.") + + if user_beatmap.beatmapset_id != item_beatmap.beatmapset_id: + raise InvokeException( + "Selected beatmap is not from the same beatmap set." + ) + + if ( + ruleset_id is not None + and user_beatmap.mode != GameMode.OSU + and ruleset_id != user_beatmap.mode + ): + raise InvokeException( + "Selected ruleset is not supported for the given beatmap." + ) + + user.beatmap_id = beatmap_id + user.ruleset_id = ruleset_id + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserStyleChanged", + user.user_id, + beatmap_id, + ruleset_id, + ) + + async def ChangeUserMods(self, client: Client, new_mods: list[APIMod]): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await self.change_user_mods(new_mods, server_room, user) + + async def change_user_mods( + self, + new_mods: list[APIMod], + room: ServerMultiplayerRoom, + user: MultiplayerRoomUser, + ): + assert room.queue + is_valid, valid_mods = room.queue.current_item.validate_user_mods( + user, new_mods + ) + if not is_valid: + incompatible_mods = [ + mod.acronym for mod in new_mods if mod not in valid_mods + ] + raise InvokeException( + f"Incompatible mods were selected: {','.join(incompatible_mods)}" + ) + + if user.mods == valid_mods: + return + + user.mods = valid_mods + + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "UserModsChanged", + user.user_id, + valid_mods, + )