diff --git a/app/database/room.py b/app/database/room.py index 368a04a..54497db 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel): auto_skip: bool auto_start_duration: int status: RoomStatus - # TODO: channel_id + channel_id: int | None = None class Room(AsyncAttrs, RoomBase, table=True): @@ -84,6 +84,7 @@ class RoomResp(RoomBase): current_playlist_item: PlaylistResp | None = None current_user_score: PlaylistAggregateScore | None = None recent_participants: list[UserResp] = Field(default_factory=list) + channel_id: int = 0 @classmethod async def from_db( @@ -93,7 +94,9 @@ class RoomResp(RoomBase): include: list[str] = [], user: User | None = None, ) -> "RoomResp": - resp = cls.model_validate(room.model_dump()) + d = room.model_dump() + d["channel_id"] = d.get("channel_id", 0) or 0 + resp = cls.model_validate(d) stats = RoomPlaylistItemStats(count_active=0, count_total=0) difficulty_range = RoomDifficultyRange( @@ -158,6 +161,7 @@ class RoomResp(RoomBase): # duration = room.settings.duration, starts_at=server_room.start_at, participant_count=len(room.users), + channel_id=server_room.room.channel_id or 0, ) return resp diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index 5f99136..9eec7b5 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -44,6 +44,7 @@ from sqlmodel import col from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: + from app.database.room import Room from app.signalr.hub import MultiplayerHub HOST_LIMIT = 50 @@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel): channel_id: int @classmethod - def from_db(cls, room) -> "MultiplayerRoom": + def from_db(cls, room: "Room") -> "MultiplayerRoom": """ 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) """ @@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel): host_user = MultiplayerRoomUser(user_id=room.host_id) # playlist 转换 playlist = [] - if hasattr(room, "playlist"): + if room.playlist: for item in room.playlist: playlist.append( PlaylistItem( @@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel): match_state=None, playlist=playlist, active_countdowns=[], - channel_id=getattr(room, "channel_id", 0), + channel_id=room.channel_id, ) diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 2c91952..6add821 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -4,10 +4,16 @@ import asyncio from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.lazer_user import User -from app.dependencies.database import DBFactory, get_db_factory, get_redis +from app.dependencies.database import ( + DBFactory, + engine, + get_db_factory, + get_redis, +) from app.dependencies.user import get_current_user from app.log import logger from app.models.chat import ChatEvent +from app.service.subscribers.chat import ChatSubscriber from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes @@ -23,6 +29,9 @@ class ChatServer: self.redis: Redis = get_redis() self.tasks: set[asyncio.Task] = set() + self.ChatSubscriber = ChatSubscriber() + self.ChatSubscriber.chat_server = self + self._subscribed = False def _add_task(self, task): task = asyncio.create_task(task) @@ -158,7 +167,13 @@ class ChatServer: del self.channels[channel_id] channel_resp = await ChatChannelResp.from_db( - channel, session, user, self.redis, self.channels[channel_id] + channel, + session, + user, + self.redis, + self.channels.get(channel_id) + if channel.type != ChannelType.PUBLIC + else None, ) client = self.connect_client.get(user_id) if client: @@ -170,6 +185,30 @@ class ChatServer: ), ) + async def join_room_channel(self, channel_id: int, user_id: int): + async with AsyncSession(engine) as session: + channel = await ChatChannel.get(channel_id, session) + if channel is None: + return + + user = await session.get(User, user_id) + if user is None: + return + + await self.join_channel(user, channel, session) + + async def leave_room_channel(self, channel_id: int, user_id: int): + async with AsyncSession(engine) as session: + channel = await ChatChannel.get(channel_id, session) + if channel is None: + return + + user = await session.get(User, user_id) + if user is None: + return + + await self.leave_channel(user, channel, session) + server = ChatServer() @@ -207,6 +246,10 @@ async def chat_websocket( authorization: str = Header(...), factory: DBFactory = Depends(get_db_factory), ): + if not server._subscribed: + server._subscribed = True + await server.ChatSubscriber.start_subscribe() + async for session in factory(): token = authorization[7:] if ( diff --git a/app/router/v2/room.py b/app/router/v2/room.py index ff62b4c..9a66e8a 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp): async def _participate_room( - room_id: int, user_id: int, db_room: Room, session: AsyncSession + room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis ): participated_user = ( await session.exec( @@ -138,6 +138,8 @@ async def _participate_room( participated_user.joined_at = datetime.now(UTC) db_room.participant_count += 1 + await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}") + @router.post( "/rooms", @@ -150,11 +152,12 @@ async def create_room( room: APIUploadedRoom, db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), + redis: Redis = Depends(get_redis), ): assert current_user.id is not None user_id = current_user.id db_room = await create_playlist_room_from_api(db, room, user_id) - await _participate_room(db_room.id, user_id, db_room, db) + await _participate_room(db_room.id, user_id, db_room, db, redis) created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db)) created_room.error = "" return created_room @@ -219,11 +222,12 @@ async def add_user_to_room( room_id: int = Path(..., description="房间 ID"), user_id: int = Path(..., description="用户 ID"), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), current_user: User = Security(get_client_user), ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is not None: - await _participate_room(room_id, user_id, db_room, db) + await _participate_room(room_id, user_id, db_room, db, redis) await db.commit() await db.refresh(db_room) resp = await RoomResp.from_db(db_room, db) @@ -243,6 +247,7 @@ async def remove_user_from_room( user_id: int = Path(..., description="用户 ID"), db: AsyncSession = Depends(get_db), current_user: User = Security(get_client_user), + redis: Redis = Depends(get_redis), ): db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is not None: @@ -257,6 +262,7 @@ async def remove_user_from_room( if participated_user is not None: participated_user.left_at = datetime.now(UTC) db_room.participant_count -= 1 + await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}") await db.commit() return None else: diff --git a/app/service/room.py b/app/service/room.py index d11dced..d6fdcc6 100644 --- a/app/service/room.py +++ b/app/service/room.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta from app.database.beatmap import Beatmap +from app.database.chat import ChannelType, ChatChannel from app.database.playlists import Playlist from app.database.room import APIUploadedRoom, Room from app.dependencies.fetcher import get_fetcher @@ -25,6 +26,18 @@ async def create_playlist_room_from_api( session.add(db_room) await session.commit() await session.refresh(db_room) + + channel = ChatChannel( + name=f"room_{db_room.id}", + description="Playlist room", + type=ChannelType.MULTIPLAYER, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(db_room) + db_room.channel_id = channel.channel_id + await add_playlists_to_room(session, db_room.id, room.playlist, host_id) await session.refresh(db_room) return db_room @@ -57,6 +70,18 @@ async def create_playlist_room( session.add(db_room) await session.commit() await session.refresh(db_room) + + channel = ChatChannel( + name=f"room_{db_room.id}", + description="Playlist room", + type=ChannelType.MULTIPLAYER, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(db_room) + db_room.channel_id = channel.channel_id + await add_playlists_to_room(session, db_room.id, playlist, host_id) await session.refresh(db_room) return db_room diff --git a/app/service/subscribers/base.py b/app/service/subscribers/base.py index 8fc5654..693d4d8 100644 --- a/app/service/subscribers/base.py +++ b/app/service/subscribers/base.py @@ -24,6 +24,11 @@ class RedisSubscriber: del self.handlers[channel] await self.pubsub.unsubscribe(channel) + def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]): + if channel not in self.handlers: + self.handlers[channel] = [] + self.handlers[channel].append(handler) + async def listen(self): while True: message = await self.pubsub.get_message( diff --git a/app/service/subscribers/chat.py b/app/service/subscribers/chat.py new file mode 100644 index 0000000..0c61c06 --- /dev/null +++ b/app/service/subscribers/chat.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .base import RedisSubscriber + +if TYPE_CHECKING: + from app.router.chat.server import ChatServer + + +JOIN_CHANNEL = "chat:room:joined" +EXIT_CHANNEL = "chat:room:left" + + +class ChatSubscriber(RedisSubscriber): + def __init__(self): + super().__init__() + self.room_subscriber: dict[int, list[int]] = {} + self.chat_server: "ChatServer | None" = None + + async def start_subscribe(self): + await self.subscribe(JOIN_CHANNEL) + self.add_handler(JOIN_CHANNEL, self.on_join_room) + await self.subscribe(EXIT_CHANNEL) + self.add_handler(EXIT_CHANNEL, self.on_leave_room) + self.start() + + async def on_join_room(self, c: str, s: str): + channel_id, user_id = s.split(":") + if self.chat_server is None: + return + await self.chat_server.join_room_channel(int(channel_id), int(user_id)) + + async def on_leave_room(self, c: str, s: str): + channel_id, user_id = s.split(":") + if self.chat_server is None: + return + await self.chat_server.leave_room_channel(int(channel_id), int(user_id)) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 4ada615..20eecf6 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -6,6 +6,7 @@ from typing import override from app.database import Room from app.database.beatmap import Beatmap +from app.database.chat import ChannelType, ChatChannel from app.database.lazer_user import User from app.database.multiplayer_event import MultiplayerEvent from app.database.playlists import Playlist @@ -195,6 +196,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await session.commit() await session.refresh(db_room) + channel = ChatChannel( + name=f"room_{db_room.id}", + description="Multiplayer room", + type=ChannelType.MULTIPLAYER, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(db_room) + room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue] + db_room.channel_id = channel.channel_id + item = room.playlist[0] item.owner_id = client.user_id room.room_id = db_room.id @@ -280,6 +293,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if db_room is None: raise InvokeException("Room does not exist in database") db_room.participant_count += 1 + + redis = get_redis() + await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}") + return room async def change_beatmap_availability( @@ -914,6 +931,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if target_store: target_store.room_id = 0 + redis = get_redis() + await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}") + async def end_room(self, room: ServerMultiplayerRoom): assert room.room.host async with AsyncSession(engine) as session: diff --git a/migrations/versions/df9f725a077c_room_add_channel_id.py b/migrations/versions/df9f725a077c_room_add_channel_id.py new file mode 100644 index 0000000..85e7d93 --- /dev/null +++ b/migrations/versions/df9f725a077c_room_add_channel_id.py @@ -0,0 +1,54 @@ +"""room: add channel_id + +Revision ID: df9f725a077c +Revises: dd33d89aa2c2 +Create Date: 2025-08-16 08:05:28.748265 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "df9f725a077c" +down_revision: str | Sequence[str] | None = "dd33d89aa2c2" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True + ) + op.alter_column( + "chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True + ) + op.create_index( + op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False + ) + op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("rooms", "channel_id") + op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users") + op.alter_column( + "chat_silence_users", + "banned_at", + existing_type=mysql.DATETIME(), + nullable=False, + ) + op.alter_column( + "chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False + ) + # ### end Alembic commands ###