feat(multiplayer): complete validation
This commit is contained in:
@@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import IntEnum
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, override
|
||||
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, cast, override
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.dependencies.database import engine
|
||||
@@ -107,6 +107,97 @@ class PlaylistItem(BaseModel):
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _get_api_mods(self):
|
||||
from app.models.mods import API_MODS, init_mods
|
||||
|
||||
if not API_MODS:
|
||||
init_mods()
|
||||
return API_MODS
|
||||
|
||||
def _validate_mod_for_ruleset(
|
||||
self, mod: APIMod, ruleset_key: int, context: str = "mod"
|
||||
) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if (
|
||||
typed_ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[typed_ruleset_key]
|
||||
):
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is invalid for this ruleset"
|
||||
)
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not playable by users"
|
||||
)
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not valid for multiplayer"
|
||||
)
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
mod1_settings = API_MODS[typed_ruleset_key].get(mod1["acronym"])
|
||||
if mod1_settings:
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(
|
||||
f"Mods {mod1['acronym']} and "
|
||||
f"{mod2['acronym']} are incompatible"
|
||||
)
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
for req_mod in self.required_mods:
|
||||
req_acronym = req_mod["acronym"]
|
||||
req_settings = API_MODS[typed_ruleset_key].get(req_acronym)
|
||||
if req_settings:
|
||||
incompatible = set(req_settings.get("IncompatibleMods", []))
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(
|
||||
f"Required mod {req_acronym} conflicts with "
|
||||
f"allowed mods: {conflict_list}"
|
||||
)
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
|
||||
# Validate required mods
|
||||
for mod in self.required_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Required mod")
|
||||
|
||||
# Validate allowed mods
|
||||
for mod in self.allowed_mods:
|
||||
self._validate_mod_for_ruleset(mod, ruleset_key, "Allowed mod")
|
||||
|
||||
# Check internal compatibility of required mods
|
||||
self._check_mod_compatibility(self.required_mods, ruleset_key)
|
||||
|
||||
# Check compatibility between required and allowed mods
|
||||
self._check_required_allowed_compatibility(ruleset_key)
|
||||
|
||||
def validate_user_mods(
|
||||
self,
|
||||
user: "MultiplayerRoomUser",
|
||||
@@ -118,10 +209,7 @@ class PlaylistItem(BaseModel):
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
from app.models.mods import API_MODS, init_mods
|
||||
|
||||
if not API_MODS:
|
||||
init_mods()
|
||||
API_MODS = self._get_api_mods()
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
@@ -367,7 +455,8 @@ class MultiplayerQueue:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
# TODO: mods validation
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(
|
||||
beatmap.difficulty_rating
|
||||
@@ -410,7 +499,7 @@ class MultiplayerQueue:
|
||||
"Attempted to change an item which has already been played"
|
||||
)
|
||||
|
||||
# TODO: mods validation
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
item.star_rating = float(beatmap.difficulty_rating)
|
||||
item.playlist_order = existing_item.playlist_order
|
||||
|
||||
Reference in New Issue
Block a user