diff --git a/app/database/playlists.py b/app/database/playlists.py index 42567b6..10ad86b 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -16,7 +16,10 @@ from sqlmodel import ( ForeignKey, Relationship, SQLModel, + func, + select, ) +from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .room import Room @@ -59,9 +62,20 @@ class Playlist(PlaylistBase, table=True): room: "Room" = Relationship() @classmethod - async def from_hub(cls, playlist: PlaylistItem, room_id: int) -> "Playlist": + async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int: + stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where( + cls.room_id == room_id + ) + result = await session.exec(stmt) + return result.one() + + @classmethod + async def from_hub( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ) -> "Playlist": + next_id = await cls.get_next_id_for_room(room_id, session=session) return cls( - id=playlist.id, + id=next_id, owner_id=playlist.owner_id, ruleset_id=playlist.ruleset_id, beatmap_id=playlist.beatmap_id, @@ -74,6 +88,50 @@ class Playlist(PlaylistBase, table=True): room_id=room_id, ) + @classmethod + async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == playlist.id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + db_playlist.owner_id = playlist.owner_id + db_playlist.ruleset_id = playlist.ruleset_id + db_playlist.beatmap_id = playlist.beatmap_id + db_playlist.required_mods = [ + msgpack_to_apimod(mod) for mod in playlist.required_mods + ] + db_playlist.allowed_mods = [ + msgpack_to_apimod(mod) for mod in playlist.allowed_mods + ] + db_playlist.expired = playlist.expired + db_playlist.playlist_order = playlist.order + db_playlist.played_at = playlist.played_at + db_playlist.freestyle = playlist.freestyle + await session.commit() + + @classmethod + async def add_to_db( + cls, playlist: PlaylistItem, room_id: int, session: AsyncSession + ): + db_playlist = await cls.from_hub(playlist, room_id, session) + session.add(db_playlist) + await session.commit() + await session.refresh(db_playlist) + playlist.id = db_playlist.id + + @classmethod + async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): + db_playlist = await session.exec( + select(cls).where(cls.id == item_id, cls.room_id == room_id) + ) + db_playlist = db_playlist.first() + if db_playlist is None: + raise ValueError("Playlist item not found") + await session.delete(db_playlist) + await session.commit() + class PlaylistResp(PlaylistBase): beatmap: BeatmapResp | None = None diff --git a/app/signalr/exception.py b/app/exception.py similarity index 100% rename from app/signalr/exception.py rename to app/exception.py diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index fa5e935..39ced12 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -1,7 +1,12 @@ from __future__ import annotations +from dataclasses import dataclass import datetime -from typing import Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal + +from app.database.beatmap import Beatmap +from app.dependencies.database import engine +from app.exception import InvokeException from .room import ( DownloadState, @@ -21,7 +26,14 @@ from .signalr import ( ) from msgpack_lazer_api import APIMod -from pydantic import BaseModel, Field, field_serializer, field_validator +from pydantic import Field, field_serializer, field_validator +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.signalr.hub import MultiplayerHub + +HOST_LIMIT = 50 +PER_USER_LIMIT = 3 class MultiplayerClientState(UserState): @@ -161,8 +173,246 @@ class MultiplayerRoom(MessagePackArrayModel): return msgpack_union_dump(v) -class ServerMultiplayerRoom(BaseModel): +class MultiplayerQueue: + def __init__(self, room: "ServerMultiplayerRoom", hub: "MultiplayerHub"): + self.server_room = room + self.hub = hub + self.current_index = 0 + + @property + def upcoming_items(self): + return sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda i: i.order, + ) + + @property + def room(self): + return self.server_room.room + + async def update_order(self): + from app.database import Playlist + + match self.room.settings.queue_mode: + case QueueMode.ALL_PLAYERS_ROUND_ROBIN: + ordered_active_items = [] + + is_first_set = True + first_set_order_by_user_id = {} + + active_items = [item for item in self.room.playlist if not item.expired] + active_items.sort(key=lambda x: x.id) + + user_item_groups = {} + for item in active_items: + if item.owner_id not in user_item_groups: + user_item_groups[item.owner_id] = [] + user_item_groups[item.owner_id].append(item) + + max_items = max( + (len(items) for items in user_item_groups.values()), default=0 + ) + + for i in range(max_items): + current_set = [] + for user_id, items in user_item_groups.items(): + if i < len(items): + current_set.append(items[i]) + + if is_first_set: + current_set.sort(key=lambda item: (item.order, item.id)) + ordered_active_items.extend(current_set) + first_set_order_by_user_id = { + item.owner_id: idx + for idx, item in enumerate(ordered_active_items) + } + else: + current_set.sort( + key=lambda item: first_set_order_by_user_id.get( + item.owner_id, 0 + ) + ) + ordered_active_items.extend(current_set) + + is_first_set = False + + for idx, item in enumerate(ordered_active_items): + item.order = idx + case _: + ordered_active_items = sorted( + (item for item in self.room.playlist if not item.expired), + key=lambda x: x.id, + ) + async with AsyncSession(engine) as session: + for idx, item in enumerate(ordered_active_items): + if item.order == idx: + continue + item.order = idx + await Playlist.update(item, self.room.room_id, session) + await self.hub.playlist_changed( + self.server_room, item, beatmap_changed=False + ) + + async def update_current_item(self): + upcoming_items = self.upcoming_items + next_item = ( + upcoming_items[0] + if upcoming_items + else max( + self.room.playlist, + key=lambda i: i.played_at or datetime.datetime.min, + ) + ) + self.current_index = self.room.playlist.index(next_item) + last_id = self.room.settings.playlist_item_id + self.room.settings.playlist_item_id = next_item.id + if last_id != next_item.id: + await self.hub.setting_changed(self.server_room, True) + + async def add_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + is_host = self.room.host and self.room.host.user_id == user.user_id + if self.room.settings.queue_mode == QueueMode.HOST_ONLY and not is_host: + raise InvokeException("You are not the host") + + limit = HOST_LIMIT if is_host else PER_USER_LIMIT + if ( + len( + list( + filter( + lambda x: x.owner_id == user.user_id, + self.room.playlist, + ) + ) + ) + >= limit + ): + raise InvokeException(f"You can only have {limit} items in the queue") + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + async with session: + beatmap = await session.get(Beatmap, item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + # TODO: mods validation + item.owner_id = user.user_id + item.star = float( + beatmap.difficulty_rating + ) # FIXME: beatmap use decimal + await Playlist.add_to_db(item, self.room.room_id, session) + self.room.playlist.append(item) + await self.hub.playlist_added(self.server_room, item) + await self.update_order() + await self.update_current_item() + + async def edit_item(self, item: PlaylistItem, user: MultiplayerRoomUser): + from app.database import Playlist + + if item.freestyle and len(item.allowed_mods) > 0: + raise InvokeException("Freestyle items cannot have allowed mods") + + async with AsyncSession(engine) as session: + async with session: + beatmap = await session.get(Beatmap, item.beatmap_id) + if beatmap is None: + raise InvokeException("Beatmap not found") + if item.checksum != beatmap.checksum: + raise InvokeException("Checksum mismatch") + + existing_item = next( + (i for i in self.room.playlist if i.id == item.id), None + ) + if existing_item is None: + raise InvokeException( + "Attempted to change an item that doesn't exist" + ) + + if existing_item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to change an item which is not owned by the user" + ) + + if existing_item.expired: + raise InvokeException( + "Attempted to change an item which has already been played" + ) + + # TODO: mods validation + item.owner_id = user.user_id + item.star = float(beatmap.difficulty_rating) + item.order = existing_item.order + + await Playlist.update(item, self.room.room_id, session) + + # Update item in playlist + for idx, playlist_item in enumerate(self.room.playlist): + if playlist_item.id == item.id: + self.room.playlist[idx] = item + break + + await self.hub.playlist_changed( + self.server_room, + item, + beatmap_changed=item.checksum != existing_item.checksum, + ) + + async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): + from app.database import Playlist + + item = next( + (i for i in self.room.playlist if i.id == playlist_item_id), + None, + ) + + if item is None: + raise InvokeException("Item does not exist in the room") + + # Check if it's the only item and current item + if item == self.current_item: + upcoming_items = [i for i in self.room.playlist if not i.expired] + if len(upcoming_items) == 1: + raise InvokeException("The only item in the room cannot be removed") + + if item.owner_id != user.user_id and self.room.host != user: + raise InvokeException( + "Attempted to remove an item which is not owned by the user" + ) + + if item.expired: + raise InvokeException( + "Attempted to remove an item which has already been played" + ) + + async with AsyncSession(engine) as session: + await Playlist.delete_item(item.id, self.room.room_id, session) + + self.room.playlist.remove(item) + self.current_index = self.room.playlist.index(self.upcoming_items[0]) + + await self.update_order() + await self.update_current_item() + await self.hub.playlist_removed(self.server_room, item.id) + + @property + def current_item(self): + """Get the current playlist item""" + current_id = self.room.settings.playlist_item_id + return next( + (item for item in self.room.playlist if item.id == current_id), + None, + ) + + +@dataclass +class ServerMultiplayerRoom: room: MultiplayerRoom category: RoomCategory status: RoomStatus start_at: datetime.datetime + queue: MultiplayerQueue | None = None diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 276140f..4e2c9d6 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -6,9 +6,9 @@ import time from typing import Any from app.config import settings +from app.exception import InvokeException from app.log import logger from app.models.signalr import UserState -from app.signalr.exception import InvokeException from app.signalr.packet import ( ClosePacket, CompletionPacket, diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 23ca69b..477396b 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -5,16 +5,19 @@ from typing import override from app.database import Room from app.database.playlists import Playlist from app.dependencies.database import engine +from app.exception import InvokeException from app.log import logger from app.models.multiplayer_hub import ( + BeatmapAvailability, MultiplayerClientState, + MultiplayerQueue, MultiplayerRoom, MultiplayerRoomUser, + PlaylistItem, ServerMultiplayerRoom, ) from app.models.room import RoomCategory, RoomStatus from app.models.signalr import serialize_to_list -from app.signalr.exception import InvokeException from .hub import Client, Hub @@ -40,6 +43,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): async def CreateRoom(self, client: Client, room: MultiplayerRoom): logger.info(f"[MultiplayerHub] {client.user_id} creating room") + store = self.get_or_create_state(client) + if store.room_id != 0: + raise InvokeException("You are already in a room") async with AsyncSession(engine) as session: async with session: db_room = Room( @@ -55,22 +61,22 @@ class MultiplayerHub(Hub[MultiplayerClientState]): session.add(db_room) await session.commit() await session.refresh(db_room) - playitem = room.playlist[0] - playitem.owner_id = client.user_id - playitem.order = 1 - db_playlist = await Playlist.from_hub(playitem, db_room.id) - session.add(db_playlist) + item = room.playlist[0] + item.owner_id = client.user_id room.room_id = db_room.id starts_at = db_room.starts_at - await session.commit() - await session.refresh(db_playlist) - # room.playlist.append() + await Playlist.add_to_db(item, db_room.id, session) server_room = ServerMultiplayerRoom( room=room, category=RoomCategory.NORMAL, status=RoomStatus.IDLE, start_at=starts_at, ) + queue = MultiplayerQueue( + room=server_room, + hub=self, + ) + server_room.queue = queue self.rooms[room.room_id] = server_room return await self.JoinRoomWithPassword( client, room.room_id, room.settings.password @@ -101,3 +107,115 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room.users.append(user) self.add_to_group(client, self.group_id(room_id)) return serialize_to_list(room) + + async def ChangeBeatmapAvailability( + self, client: Client, beatmap_availability: BeatmapAvailability + ): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + availability = user.availability + if ( + availability.state == beatmap_availability.state + and availability.progress == beatmap_availability.progress + ): + return + user.availability = availability + await self.broadcast_group_call( + self.group_id(store.room_id), + "UserBeatmapAvailabilityChanged", + user.user_id, + serialize_to_list(beatmap_availability), + ) + + async def AddPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.add_item( + item, + user, + ) + + async def EditPlaylistItem(self, client: Client, item: PlaylistItem): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.edit_item( + item, + user, + ) + + async def RemovePlaylistItem(self, client: Client, item_id: int): + store = self.get_or_create_state(client) + if store.room_id == 0: + raise InvokeException("You are not in a room") + if store.room_id not in self.rooms: + raise InvokeException("Room does not exist") + server_room = self.rooms[store.room_id] + room = server_room.room + assert server_room.queue + user = next((u for u in room.users if u.user_id == client.user_id), None) + if user is None: + raise InvokeException("You are not in this room") + + await server_room.queue.remove_item( + item_id, + user, + ) + + async def setting_changed(self, room: ServerMultiplayerRoom, beatmap_changed: bool): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "SettingsChanged", + serialize_to_list(room.room.settings), + ) + + async def playlist_added(self, room: ServerMultiplayerRoom, item: PlaylistItem): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemAdded", + serialize_to_list(item), + ) + + async def playlist_removed(self, room: ServerMultiplayerRoom, item_id: int): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemRemoved", + item_id, + ) + + async def playlist_changed( + self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool + ): + await self.broadcast_group_call( + self.group_id(room.room.room_id), + "PlaylistItemChanged", + serialize_to_list(item), + )