Files
g0v0-server/app/models/multiplayer_hub.py
2025-08-03 15:14:30 +00:00

752 lines
24 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import IntEnum
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, override
from app.database.beatmap import Beatmap
from app.dependencies.database import engine
from app.exception import InvokeException
from .mods import APIMod
from .room import (
DownloadState,
MatchType,
MultiplayerRoomState,
MultiplayerUserState,
QueueMode,
RoomCategory,
RoomStatus,
)
from .signalr import (
SignalRMeta,
SignalRUnionMessage,
UserState,
)
from pydantic import BaseModel, Field
from sqlalchemy import update
from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.signalr.hub import MultiplayerHub
# Maximum number of playlist items the room host may own; enforced by
# MultiplayerQueue.add_item.
HOST_LIMIT = 50
# Maximum number of playlist items any non-host user may own; enforced by
# MultiplayerQueue.add_item.
PER_USER_LIMIT = 3
class MultiplayerClientState(UserState):
    """Client state for the multiplayer hub, extending the shared ``UserState``."""

    # Room associated with this client; defaults to 0.
    # NOTE(review): 0 presumably means "not in a room" - confirm against the hub.
    room_id: int = 0
class MultiplayerRoomSettings(BaseModel):
    """Settings of a multiplayer room.

    NOTE(review): field order may be significant for the SignalR encoding -
    confirm before reordering any declaration.
    """

    name: str = "Unnamed Room"
    # ID of the room's current playlist item; written by
    # MultiplayerQueue.update_current_item.
    # NOTE(review): the effect of SignalRMeta(use_abbr=False) is defined in
    # .signalr and is not visible here.
    playlist_item_id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
    password: str = ""
    match_type: MatchType = MatchType.HEAD_TO_HEAD
    # Read by MultiplayerQueue to decide who may add items and how the queue
    # is ordered.
    queue_mode: QueueMode = QueueMode.HOST_ONLY
    # NOTE(review): a zero duration presumably disables auto-start - confirm.
    auto_start_duration: timedelta = timedelta(seconds=0)
    auto_skip: bool = False
class BeatmapAvailability(BaseModel):
    """Download state of a beatmap as reported for a room user."""

    state: DownloadState = DownloadState.UNKNOWN
    # NOTE(review): presumably the download progress while a download is in
    # flight; range and units are not established in this file.
    progress: float | None = None
# Base class for the per-user match state union; concrete members set
# ``union_type``.
class _MatchUserState(SignalRUnionMessage): ...


class TeamVersusUserState(_MatchUserState):
    """Per-user match state in team-versus rooms: the team the user is on."""

    # ID of one of the teams listed in TeamVersusRoomState.teams
    # (validated by TeamVersusHandler.handle_request).
    team_id: int
    # NOTE(review): presumably the union discriminator used by the SignalR
    # encoder - confirm in .signalr.
    union_type: ClassVar[Literal[0]] = 0


# Union of all per-user match states; there is currently a single member.
MatchUserState = TeamVersusUserState
# Base class for the per-room match state union; concrete members set
# ``union_type``.
class _MatchRoomState(SignalRUnionMessage): ...


class MultiplayerTeam(BaseModel):
    """A team that users can be assigned to in a team-versus room."""

    id: int
    name: str


class TeamVersusRoomState(_MatchRoomState):
    """Room-level match state for team-versus rooms."""

    # Defaults to two teams: "Team Red" (id 0) and "Team Blue" (id 1).
    # A factory is used so every instance gets its own list.
    teams: list[MultiplayerTeam] = Field(
        default_factory=lambda: [
            MultiplayerTeam(id=0, name="Team Red"),
            MultiplayerTeam(id=1, name="Team Blue"),
        ]
    )
    # NOTE(review): presumably the union discriminator used by the SignalR
    # encoder - confirm in .signalr.
    union_type: ClassVar[Literal[0]] = 0


# Union of all per-room match states; there is currently a single member.
MatchRoomState = TeamVersusRoomState
class PlaylistItem(BaseModel):
    """A single entry of a multiplayer room playlist."""

    id: Annotated[int, Field(default=0), SignalRMeta(use_abbr=False)]
    owner_id: int
    beatmap_id: int
    beatmap_checksum: str
    ruleset_id: int
    required_mods: list[APIMod] = Field(default_factory=list)
    allowed_mods: list[APIMod] = Field(default_factory=list)
    expired: bool
    playlist_order: int
    played_at: datetime | None = None
    star_rating: float
    freestyle: bool

    def validate_user_mods(
        self,
        user: "MultiplayerRoomUser",
        proposed_mods: list[APIMod],
    ) -> tuple[bool, list[APIMod]]:
        """Filter a user's proposed mods down to those this item accepts.

        The mods pass through four filters, in order:

        1. the mod must exist for the effective ruleset (the user's own
           ``ruleset_id`` when set, otherwise the item's);
        2. the mod must not be declared incompatible by a mod accepted
           earlier in the same list;
        3. unless the item is freestyle, the mod must be in ``allowed_mods``;
        4. the mod must not declare any other surviving or required mod as
           incompatible.

        Returns:
            ``(is_valid, valid_mods)`` where ``is_valid`` is ``True`` only if
            no proposed mod was dropped, and ``valid_mods`` are the survivors
            in their original order.
        """
        from typing import Literal, cast

        from app.models.mods import API_MODS, init_mods

        # The mod table is filled lazily on first use.
        if not API_MODS:
            init_mods()

        effective_ruleset = (
            self.ruleset_id if user.ruleset_id is None else user.ruleset_id
        )
        ruleset_key = cast(Literal[0, 1, 2, 3], effective_ruleset)
        nothing_dropped = True

        # 1) Keep only mods known for the ruleset.
        ruleset_known = ruleset_key in API_MODS
        recognised = []
        for mod in proposed_mods:
            if ruleset_known and mod["acronym"] in API_MODS[ruleset_key]:
                recognised.append(mod)
            else:
                nothing_dropped = False

        # 2) Drop mods excluded by a mod that came earlier in the list.
        excluded_acronyms = set()
        candidates = []
        for mod in recognised:
            acronym = mod["acronym"]
            if acronym in excluded_acronyms:
                nothing_dropped = False
                continue
            definition = API_MODS[ruleset_key].get(acronym)
            if definition:
                excluded_acronyms.update(definition["IncompatibleMods"])
            candidates.append(mod)

        # 3) Outside freestyle, only explicitly allowed mods survive.
        if not self.freestyle:
            allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
            permitted = [
                mod for mod in candidates if mod["acronym"] in allowed_acronyms
            ]
            if len(permitted) != len(candidates):
                nothing_dropped = False
            candidates = permitted

        # 4) Drop mods that clash with required mods or other survivors.
        required_acronyms = {mod["acronym"] for mod in self.required_mods}
        combined_acronyms = {
            mod["acronym"] for mod in candidates
        } | required_acronyms
        accepted = []
        for mod in candidates:
            acronym = mod["acronym"]
            definition = API_MODS[ruleset_key].get(acronym)
            clashes = bool(definition) and any(
                other != acronym and other in definition["IncompatibleMods"]
                for other in combined_acronyms
            )
            if clashes:
                nothing_dropped = False
            else:
                accepted.append(mod)

        return nothing_dropped, accepted

    def clone(self) -> "PlaylistItem":
        """Return a copy of this item whose mod lists are independent lists.

        Every other field (including ``id``, ``expired`` and ``played_at``) is
        carried over unchanged.
        """
        duplicate = self.model_copy()
        duplicate.required_mods = [*self.required_mods]
        duplicate.allowed_mods = [*self.allowed_mods]
        return duplicate
class _MultiplayerCountdown(BaseModel):
    """Fields shared by every countdown kind."""

    # Overwritten with a fresh ID by ServerMultiplayerRoom.start_countdown.
    id: int = 0
    # Time left when the countdown is created; CountdownInfo derives its sleep
    # duration from it (negative values are treated as zero there).
    remaining: timedelta
    # When true, ServerMultiplayerRoom.start_countdown stops all other
    # countdowns before starting this one.
    is_exclusive: bool = False


class MatchStartCountdown(_MultiplayerCountdown):
    # NOTE(review): union_type is presumably the SignalR union discriminator.
    union_type: ClassVar[Literal[0]] = 0


class ForceGameplayStartCountdown(_MultiplayerCountdown):
    union_type: ClassVar[Literal[1]] = 1


class ServerShuttingDownCountdown(_MultiplayerCountdown):
    union_type: ClassVar[Literal[2]] = 2


# Union of every countdown kind a room can track.
MultiplayerCountdown = (
    MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
)
class MultiplayerRoomUser(BaseModel):
    """A user as seen inside a multiplayer room."""

    user_id: int
    state: MultiplayerUserState = MultiplayerUserState.IDLE
    availability: BeatmapAvailability = BeatmapAvailability(
        state=DownloadState.UNKNOWN, progress=None
    )
    mods: list[APIMod] = Field(default_factory=list)
    # Match-type specific state; the match type handlers set or clear it.
    match_state: MatchUserState | None = None
    # Ruleset picked by the user for freestyle; when set it takes precedence
    # over the playlist item's ruleset in PlaylistItem.validate_user_mods.
    ruleset_id: int | None = None  # freestyle
    # NOTE(review): presumably the beatmap picked by the user for freestyle -
    # not read anywhere in this file.
    beatmap_id: int | None = None  # freestyle
class MultiplayerRoom(BaseModel):
    """State of a multiplayer room: users, settings, playlist and countdowns."""

    room_id: int
    state: MultiplayerRoomState
    settings: MultiplayerRoomSettings
    users: list[MultiplayerRoomUser] = Field(default_factory=list)
    # None when the room has no host.
    host: MultiplayerRoomUser | None = None
    # Set by the match type handler (e.g. TeamVersusHandler assigns its state).
    match_state: MatchRoomState | None = None
    # All items, played and upcoming; ordering is managed by MultiplayerQueue.
    playlist: list[PlaylistItem] = Field(default_factory=list)
    # Kept in sync with the countdowns tracked by ServerMultiplayerRoom.
    active_countdowns: list[MultiplayerCountdown] = Field(default_factory=list)
    # NOTE(review): presumably the chat channel bound to the room - confirm.
    channel_id: int
class MultiplayerQueue:
def __init__(self, room: "ServerMultiplayerRoom"):
self.server_room = room
self.current_index = 0
@property
def hub(self) -> "MultiplayerHub":
return self.server_room.hub
@property
def upcoming_items(self):
return sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda i: i.playlist_order,
)
@property
def room(self):
return self.server_room.room
async def update_order(self):
from app.database import Playlist
match self.room.settings.queue_mode:
case QueueMode.ALL_PLAYERS_ROUND_ROBIN:
ordered_active_items = []
is_first_set = True
first_set_order_by_user_id = {}
active_items = [item for item in self.room.playlist if not item.expired]
active_items.sort(key=lambda x: x.id)
user_item_groups = {}
for item in active_items:
if item.owner_id not in user_item_groups:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max(
(len(items) for items in user_item_groups.values()), default=0
)
for i in range(max_items):
current_set = []
for user_id, items in user_item_groups.items():
if i < len(items):
current_set.append(items[i])
if is_first_set:
current_set.sort(
key=lambda item: (item.playlist_order, item.id)
)
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx
for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(
key=lambda item: first_set_order_by_user_id.get(
item.owner_id, 0
)
)
ordered_active_items.extend(current_set)
is_first_set = False
for idx, item in enumerate(ordered_active_items):
item.playlist_order = idx
case _:
ordered_active_items = sorted(
(item for item in self.room.playlist if not item.expired),
key=lambda x: x.id,
)
async with AsyncSession(engine) as session:
for idx, item in enumerate(ordered_active_items):
if item.playlist_order == idx:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(
self.server_room, item, beatmap_changed=False
)
async def update_current_item(self):
upcoming_items = self.upcoming_items
next_item = (
upcoming_items[0]
if upcoming_items
else max(
self.room.playlist,
key=lambda i: i.played_at or datetime.min,
)
)
self.current_index = self.room.playlist.index(next_item)
last_id = self.room.settings.playlist_item_id
self.room.settings.playlist_item_id = next_item.id
if last_id != next_item.id:
await self.hub.setting_changed(self.server_room, True)
async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
is_host = self.room.host and self.room.host.user_id == user.user_id
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if (
len([True for u in self.room.playlist if u.owner_id == user.user_id])
>= limit
):
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
async with session:
beatmap = await session.get(Beatmap, item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
# TODO: mods validation
item.owner_id = user.user_id
item.star_rating = float(
beatmap.difficulty_rating
) # FIXME: beatmap use decimal
await Playlist.add_to_db(item, self.room.room_id, session)
self.room.playlist.append(item)
await self.hub.playlist_added(self.server_room, item)
await self.update_order()
await self.update_current_item()
async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser):
from app.database import Playlist
if item.freestyle and len(item.allowed_mods) > 0:
raise InvokeException("Freestyle items cannot have allowed mods")
async with AsyncSession(engine) as session:
async with session:
beatmap = await session.get(Beatmap, item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next(
(i for i in self.room.playlist if i.id == item.id), None
)
if existing_item is None:
raise InvokeException(
"Attempted to change an item that doesn't exist"
)
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to change an item which is not owned by the user"
)
if existing_item.expired:
raise InvokeException(
"Attempted to change an item which has already been played"
)
# TODO: mods validation
item.owner_id = user.user_id
item.star_rating = float(beatmap.difficulty_rating)
item.playlist_order = existing_item.playlist_order
await Playlist.update(item, self.room.room_id, session)
# Update item in playlist
for idx, playlist_item in enumerate(self.room.playlist):
if playlist_item.id == item.id:
self.room.playlist[idx] = item
break
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum
!= existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
from app.database import Playlist
item = next(
(i for i in self.room.playlist if i.id == playlist_item_id),
None,
)
if item is None:
raise InvokeException("Item does not exist in the room")
# Check if it's the only item and current item
if item == self.current_item:
upcoming_items = [i for i in self.room.playlist if not i.expired]
if len(upcoming_items) == 1:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to remove an item which is not owned by the user"
)
if item.expired:
raise InvokeException(
"Attempted to remove an item which has already been played"
)
async with AsyncSession(engine) as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
self.room.playlist.remove(item)
self.current_index = self.room.playlist.index(self.upcoming_items[0])
await self.update_order()
await self.update_current_item()
await self.hub.playlist_removed(self.server_room, item.id)
async def finish_current_item(self):
from app.database import Playlist
async with AsyncSession(engine) as session:
played_at = datetime.now(UTC)
await session.execute(
update(Playlist)
.where(
col(Playlist.id) == self.current_item.id,
col(Playlist.room_id) == self.room.room_id,
)
.values(expired=True, played_at=played_at)
)
self.room.playlist[self.current_index].expired = True
self.room.playlist[self.current_index].played_at = played_at
await self.hub.playlist_changed(self.server_room, self.current_item, True)
await self.update_order()
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
async def update_queue_mode(self):
if self.room.settings.queue_mode == QueueMode.HOST_ONLY and all(
playitem.expired for playitem in self.room.playlist
):
assert self.room.host
await self.add_item(self.current_item.clone(), self.room.host)
await self.update_order()
await self.update_current_item()
@property
def current_item(self):
return self.room.playlist[self.current_index]
@dataclass
class CountdownInfo:
    """Server-side bookkeeping for one running countdown."""

    countdown: MultiplayerCountdown
    # How long the countdown task sleeps; never negative.
    duration: timedelta
    # Task driving the countdown; assigned by the room after construction.
    task: asyncio.Task | None = None

    def __init__(self, countdown: MultiplayerCountdown):
        """Track *countdown*, clamping a negative ``remaining`` to zero."""
        zero = timedelta(seconds=0)
        remaining = countdown.remaining
        self.countdown = countdown
        self.duration = remaining if remaining > zero else zero
# Base class for client-to-server match requests; concrete members set
# ``union_type``.
class _MatchRequest(SignalRUnionMessage): ...


class ChangeTeamRequest(_MatchRequest):
    """Request to move the sending user to another team.

    Handled by TeamVersusHandler.handle_request.
    """

    union_type: ClassVar[Literal[0]] = 0
    team_id: int


class StartMatchCountdownRequest(_MatchRequest):
    # NOTE(review): presumably asks the server to start a match-start
    # countdown of ``duration`` - the handling code is not in this file.
    union_type: ClassVar[Literal[1]] = 1
    duration: timedelta


class StopCountdownRequest(_MatchRequest):
    # NOTE(review): ``id`` presumably refers to a countdown ID - the handling
    # code is not in this file.
    union_type: ClassVar[Literal[2]] = 2
    id: int


# Union of every match request a client can send.
MatchRequest = ChangeTeamRequest | StartMatchCountdownRequest | StopCountdownRequest
class MatchTypeHandler(ABC):
    """Interface for match-type specific behaviour of a room."""

    def __init__(self, room: "ServerMultiplayerRoom"):
        """Bind the handler to *room* and to the hub serving it."""
        self.room = room
        self.hub = room.hub

    @abstractmethod
    async def handle_join(self, user: MultiplayerRoomUser):
        """React to *user* joining the room."""

    @abstractmethod
    async def handle_request(
        self, user: MultiplayerRoomUser, request: MatchRequest
    ):
        """React to a match *request* sent by *user*."""

    @abstractmethod
    async def handle_leave(self, user: MultiplayerRoomUser):
        """React to *user* leaving the room."""
class HeadToHeadHandler(MatchTypeHandler):
    """Handler for head-to-head rooms, where users carry no match state."""

    @override
    async def handle_join(self, user: MultiplayerRoomUser):
        """Clear *user*'s match state, announcing the change if there was one."""
        if user.match_state is None:
            return
        user.match_state = None
        await self.hub.change_user_match_state(self.room, user)

    @override
    async def handle_request(
        self, user: MultiplayerRoomUser, request: MatchRequest
    ):
        """Match requests are ignored in head-to-head rooms."""

    @override
    async def handle_leave(self, user: MultiplayerRoomUser):
        """Nothing to clean up when a user leaves."""
class TeamVersusHandler(MatchTypeHandler):
    """Handler for team-versus rooms: assigns users to teams."""

    @override
    def __init__(self, room: "ServerMultiplayerRoom"):
        """Install a fresh team-versus state on the room and broadcast it.

        The broadcast runs as a background task because ``__init__`` cannot
        await; the task is held in ``hub.tasks`` until it completes.
        """
        super().__init__(room)
        self.state = TeamVersusRoomState()
        room.room.match_state = self.state
        broadcast = asyncio.create_task(self.hub.change_room_match_state(self.room))
        self.hub.tasks.add(broadcast)
        broadcast.add_done_callback(self.hub.tasks.discard)

    def _get_best_available_team(self) -> int:
        """Pick the team a joining user should be placed on.

        Preference order: the first team without any member, then the team
        with the fewest members (earliest-seen team wins ties), then the first
        configured team, and finally ``0`` when no teams are configured.
        """
        from collections import Counter

        members_per_team = Counter(
            user.match_state.team_id
            for user in self.room.room.users
            if isinstance(user.match_state, TeamVersusUserState)
        )

        for team in self.state.teams:
            if team.id not in members_per_team:
                return team.id

        if members_per_team:
            # min() returns the first minimal key in insertion order.
            return min(members_per_team, key=members_per_team.__getitem__)

        return self.state.teams[0].id if self.state.teams else 0

    @override
    async def handle_join(self, user: MultiplayerRoomUser):
        """Place *user* on the best available team and announce it."""
        user.match_state = TeamVersusUserState(
            team_id=self._get_best_available_team()
        )
        await self.hub.change_user_match_state(self.room, user)

    @override
    async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest):
        """Handle a team change; any other request type is ignored.

        Raises:
            InvokeException: if the requested team does not exist.
        """
        if not isinstance(request, ChangeTeamRequest):
            return
        if all(team.id != request.team_id for team in self.state.teams):
            raise InvokeException("Invalid team ID")
        user.match_state = TeamVersusUserState(team_id=request.team_id)
        await self.hub.change_user_match_state(self.room, user)

    @override
    async def handle_leave(self, user: MultiplayerRoomUser):
        """Nothing to clean up when a user leaves."""
# Maps each supported match type to the handler class that
# ServerMultiplayerRoom.set_handler instantiates for it.
MATCH_TYPE_HANDLERS = {
    MatchType.HEAD_TO_HEAD: HeadToHeadHandler,
    MatchType.TEAM_VERSUS: TeamVersusHandler,
}
@dataclass
class ServerMultiplayerRoom:
    """Server-side wrapper around a ``MultiplayerRoom``.

    Owns the room's playlist queue, its match type handler and the tasks
    driving its countdowns.
    """

    room: MultiplayerRoom
    category: RoomCategory
    status: RoomStatus
    start_at: datetime
    hub: "MultiplayerHub"
    # Only available after set_handler() has run.
    match_type_handler: MatchTypeHandler
    queue: MultiplayerQueue
    _next_countdown_id: int
    _countdown_id_lock: asyncio.Lock
    # Running countdowns keyed by countdown ID.
    _tracked_countdown: dict[int, CountdownInfo]

    def __init__(
        self,
        room: MultiplayerRoom,
        category: RoomCategory,
        start_at: datetime,
        hub: "MultiplayerHub",
    ):
        """Wrap *room*; the room starts idle with no countdowns tracked."""
        self.room = room
        self.category = category
        self.status = RoomStatus.IDLE
        self.start_at = start_at
        self.hub = hub
        self.queue = MultiplayerQueue(self)
        self._next_countdown_id = 0
        self._countdown_id_lock = asyncio.Lock()
        self._tracked_countdown = {}

    async def set_handler(self):
        """Install the handler for the room's match type and re-join all users."""
        self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
            self
        )
        for i in self.room.users:
            await self.match_type_handler.handle_join(i)

    async def get_next_countdown_id(self) -> int:
        """Return a fresh countdown ID; IDs start at 1 and increase by one."""
        async with self._countdown_id_lock:
            self._next_countdown_id += 1
            return self._next_countdown_id

    async def start_countdown(
        self,
        countdown: MultiplayerCountdown,
        on_complete: Callable[["ServerMultiplayerRoom"], Awaitable[Any]] | None = None,
    ):
        """Start tracking *countdown* and announce it to the room.

        An exclusive countdown first stops every countdown already running.
        When the countdown elapses it is stopped and, if given, *on_complete*
        is awaited with this room.
        """

        async def _countdown_task(self: "ServerMultiplayerRoom"):
            # ``info`` is bound below, before this task gets to run.
            await asyncio.sleep(info.duration.total_seconds())
            await self.stop_countdown(countdown)
            if on_complete is not None:
                await on_complete(self)

        if countdown.is_exclusive:
            await self.stop_all_countdowns()
        countdown.id = await self.get_next_countdown_id()
        info = CountdownInfo(countdown)
        self.room.active_countdowns.append(info.countdown)
        self._tracked_countdown[countdown.id] = info
        await self.hub.send_match_event(
            self, CountdownStartedEvent(countdown=info.countdown)
        )
        info.task = asyncio.create_task(_countdown_task(self))

    async def stop_countdown(self, countdown: MultiplayerCountdown):
        """Stop tracking *countdown* and announce that it stopped.

        Does nothing when no countdown with that ID is tracked.
        """
        info = self._tracked_countdown.get(countdown.id)
        if info is None:
            return

        task = info.task
        # The countdown task calls this method on itself when it elapses.
        # Cancelling the running task would raise CancelledError at its next
        # suspension point and the completion callback would never run, so
        # only a task other than the current one is cancelled.
        if (
            task is not None
            and not task.done()
            and task is not asyncio.current_task()
        ):
            task.cancel()

        del self._tracked_countdown[countdown.id]
        # Remove the tracked object rather than the caller's: the caller may
        # hold a different instance carrying the same ID.
        if info.countdown in self.room.active_countdowns:
            self.room.active_countdowns.remove(info.countdown)
        await self.hub.send_match_event(self, CountdownStoppedEvent(id=countdown.id))

    async def stop_all_countdowns(self):
        """Stop every tracked countdown and clear all countdown state."""
        # Iterate over a snapshot: stop_countdown mutates the dict.
        for countdown in list(self._tracked_countdown.values()):
            await self.stop_countdown(countdown.countdown)

        self._tracked_countdown.clear()
        self.room.active_countdowns.clear()
# Base class for server-to-client match events.
class _MatchServerEvent(BaseModel): ...


class CountdownStartedEvent(_MatchServerEvent):
    """Sent by ServerMultiplayerRoom.start_countdown when a countdown begins."""

    countdown: MultiplayerCountdown
    # Excluded from pydantic serialization (exclude=True).
    # NOTE(review): presumably the event discriminator consumed by the
    # SignalR encoder - confirm.
    type: Literal[0] = Field(default=0, exclude=True)


class CountdownStoppedEvent(_MatchServerEvent):
    """Sent by ServerMultiplayerRoom.stop_countdown; ``id`` is the countdown ID."""

    id: int
    # Excluded from pydantic serialization (exclude=True).
    type: Literal[1] = Field(default=1, exclude=True)


# Union of every match event the server can send.
MatchServerEvent = CountdownStartedEvent | CountdownStoppedEvent
class GameplayAbortReason(IntEnum):
    # NOTE(review): judging by the names, these are the reasons gameplay can
    # be aborted; the integer values presumably have to match the client's
    # enum - confirm before renumbering.
    LOAD_TOOK_TOO_LONG = 0
    HOST_ABORTED = 1