diff --git a/app/models/mods.py b/app/models/mods.py index fb07b0c..b96811f 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -215,3 +215,35 @@ def get_speed_rate(mods: list[APIMod]): if mod["acronym"] in {"DT", "NC", "HT", "DC"}: rate *= mod.get("settings", {}).get("speed_change", 1.0) # pyright: ignore[reportOperatorIssue] return rate + + +def get_available_mods(ruleset_id: int, required_mods: list[APIMod]) -> list[APIMod]: + if ruleset_id not in API_MODS: + return [] + + ruleset_mods = API_MODS[ruleset_id] + required_mod_acronyms = {mod["acronym"] for mod in required_mods} + + incompatible_mods = set() + for mod_acronym in required_mod_acronyms: + if mod_acronym in ruleset_mods: + incompatible_mods.update(ruleset_mods[mod_acronym]["IncompatibleMods"]) + + available_mods = [] + for mod_acronym, mod_data in ruleset_mods.items(): + if mod_acronym in required_mod_acronyms: + continue + + if mod_acronym in incompatible_mods: + continue + + if any( + required_acronym in mod_data["IncompatibleMods"] + for required_acronym in required_mod_acronyms + ): + continue + + if mod_data.get("UserPlayable", False): + available_mods.append(mod_acronym) + + return [APIMod(acronym=acronym) for acronym in available_mods] diff --git a/app/router/chat/banchobot.py b/app/router/chat/banchobot.py index 6cddee9..7aefccc 100644 --- a/app/router/chat/banchobot.py +++ b/app/router/chat/banchobot.py @@ -16,7 +16,7 @@ from app.database.score import Score from app.database.statistics import UserStatistics, get_rank from app.dependencies.fetcher import get_fetcher from app.exception import InvokeException -from app.models.mods import APIMod, mod_to_save +from app.models.mods import APIMod, get_available_mods, mod_to_save from app.models.multiplayer_hub import ( ChangeTeamRequest, ServerMultiplayerRoom, @@ -501,6 +501,7 @@ async def _mp_mods( required_mods = [] allowed_mods = [] freestyle = False + freemod = False for arg in args: arg = arg.upper() if arg == "NONE": @@ -509,6 +510,8 @@ async def _mp_mods( break elif arg == "FREESTYLE": freestyle = True + elif arg == "FREEMOD": + freemod = True elif arg.startswith("+"): mod = arg.removeprefix("+") if len(mod) != 2: @@ -524,10 +527,14 @@ async def _mp_mods( item = current_item.model_copy(deep=True) item.owner_id = signalr_client.user_id item.freestyle = freestyle - if not freestyle: - item.allowed_mods = allowed_mods - else: + if freestyle: item.allowed_mods = [] + elif freemod: + item.allowed_mods = get_available_mods( + current_item.ruleset_id, required_mods + ) + else: + item.allowed_mods = allowed_mods item.required_mods = required_mods if item.expired: item.id = 0 diff --git a/app/service/daily_challenge.py b/app/service/daily_challenge.py index 0d4f93e..7b10874 100644 --- a/app/service/daily_challenge.py +++ b/app/service/daily_challenge.py @@ -15,7 +15,7 @@ from app.dependencies.database import get_redis, with_db from app.dependencies.scheduler import get_scheduler from app.log import logger from app.models.metadata_hub import DailyChallengeInfo -from app.models.mods import APIMod +from app.models.mods import APIMod, get_available_mods from app.models.room import RoomCategory from app.utils import are_same_weeks @@ -25,7 +25,11 @@ from sqlmodel import col, select async def create_daily_challenge_room( - beatmap: int, ruleset_id: int, duration: int, required_mods: list[APIMod] = [] + beatmap: int, + ruleset_id: int, + duration: int, + required_mods: list[APIMod] = [], + allowed_mods: list[APIMod] = [], ) -> Room: async with with_db() as session: today = datetime.now(UTC).date() @@ -41,6 +45,7 @@ async def create_daily_challenge_room( ruleset_id=ruleset_id, beatmap_id=beatmap, required_mods=required_mods, + allowed_mods=allowed_mods, ) ], category=RoomCategory.DAILY_CHALLENGE, @@ -73,6 +78,7 @@ async def daily_challenge_job(): beatmap = await redis.hget(key, "beatmap") # pyright: ignore[reportGeneralTypeIssues] ruleset_id = await redis.hget(key, "ruleset_id") # pyright: ignore[reportGeneralTypeIssues] required_mods = await redis.hget(key, "required_mods") # pyright: ignore[reportGeneralTypeIssues] + allowed_mods = await redis.hget(key, "allowed_mods") # pyright: ignore[reportGeneralTypeIssues] if beatmap is None or ruleset_id is None: logger.warning( @@ -89,9 +95,14 @@ async def daily_challenge_job(): beatmap_int = int(beatmap) ruleset_id_int = int(ruleset_id) - mods_list = [] + required_mods_list = [] + allowed_mods_list = [] if required_mods: - mods_list = json.loads(required_mods) + required_mods_list = json.loads(required_mods) + if allowed_mods: + allowed_mods_list = json.loads(allowed_mods) + else: + allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list) next_day = (now + timedelta(days=1)).replace( hour=0, minute=0, second=0, microsecond=0 @@ -99,7 +110,8 @@ async def daily_challenge_job(): room = await create_daily_challenge_room( beatmap=beatmap_int, ruleset_id=ruleset_id_int, - required_mods=mods_list, + required_mods=required_mods_list, + allowed_mods=allowed_mods_list, duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60), ) await MetadataHubs.broadcast_call(