feat(chat): support mp/playlist chat

This commit is contained in:
MingxuanGame
2025-08-16 08:42:40 +00:00
parent 368bdfe588
commit 3de73f2420
9 changed files with 206 additions and 10 deletions

View File

@@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel):
auto_skip: bool auto_skip: bool
auto_start_duration: int auto_start_duration: int
status: RoomStatus status: RoomStatus
# TODO: channel_id channel_id: int | None = None
class Room(AsyncAttrs, RoomBase, table=True): class Room(AsyncAttrs, RoomBase, table=True):
@@ -84,6 +84,7 @@ class RoomResp(RoomBase):
current_playlist_item: PlaylistResp | None = None current_playlist_item: PlaylistResp | None = None
current_user_score: PlaylistAggregateScore | None = None current_user_score: PlaylistAggregateScore | None = None
recent_participants: list[UserResp] = Field(default_factory=list) recent_participants: list[UserResp] = Field(default_factory=list)
channel_id: int = 0
@classmethod @classmethod
async def from_db( async def from_db(
@@ -93,7 +94,9 @@ class RoomResp(RoomBase):
include: list[str] = [], include: list[str] = [],
user: User | None = None, user: User | None = None,
) -> "RoomResp": ) -> "RoomResp":
resp = cls.model_validate(room.model_dump()) d = room.model_dump()
d["channel_id"] = d.get("channel_id", 0) or 0
resp = cls.model_validate(d)
stats = RoomPlaylistItemStats(count_active=0, count_total=0) stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange( difficulty_range = RoomDifficultyRange(
@@ -158,6 +161,7 @@ class RoomResp(RoomBase):
# duration = room.settings.duration, # duration = room.settings.duration,
starts_at=server_room.start_at, starts_at=server_room.start_at,
participant_count=len(room.users), participant_count=len(room.users),
channel_id=server_room.room.channel_id or 0,
) )
return resp return resp

View File

@@ -44,6 +44,7 @@ from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.database.room import Room
from app.signalr.hub import MultiplayerHub from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50 HOST_LIMIT = 50
@@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel):
channel_id: int channel_id: int
@classmethod @classmethod
def from_db(cls, room) -> "MultiplayerRoom": def from_db(cls, room: "Room") -> "MultiplayerRoom":
""" """
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型) 将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
""" """
@@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel):
host_user = MultiplayerRoomUser(user_id=room.host_id) host_user = MultiplayerRoomUser(user_id=room.host_id)
# playlist 转换 # playlist 转换
playlist = [] playlist = []
if hasattr(room, "playlist"): if room.playlist:
for item in room.playlist: for item in room.playlist:
playlist.append( playlist.append(
PlaylistItem( PlaylistItem(
@@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel):
match_state=None, match_state=None,
playlist=playlist, playlist=playlist,
active_countdowns=[], active_countdowns=[],
channel_id=getattr(room, "channel_id", 0), channel_id=room.channel_id,
) )

View File

@@ -4,10 +4,16 @@ import asyncio
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import DBFactory, get_db_factory, get_redis from app.dependencies.database import (
DBFactory,
engine,
get_db_factory,
get_redis,
)
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.log import logger from app.log import logger
from app.models.chat import ChatEvent from app.models.chat import ChatEvent
from app.service.subscribers.chat import ChatSubscriber
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes from fastapi.security import SecurityScopes
@@ -23,6 +29,9 @@ class ChatServer:
self.redis: Redis = get_redis() self.redis: Redis = get_redis()
self.tasks: set[asyncio.Task] = set() self.tasks: set[asyncio.Task] = set()
self.ChatSubscriber = ChatSubscriber()
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task): def _add_task(self, task):
task = asyncio.create_task(task) task = asyncio.create_task(task)
@@ -158,7 +167,13 @@ class ChatServer:
del self.channels[channel_id] del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db( channel_resp = await ChatChannelResp.from_db(
channel, session, user, self.redis, self.channels[channel_id] channel,
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
) )
client = self.connect_client.get(user_id) client = self.connect_client.get(user_id)
if client: if client:
@@ -170,6 +185,30 @@ class ChatServer:
), ),
) )
async def join_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.join_channel(user, channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.leave_channel(user, channel, session)
server = ChatServer() server = ChatServer()
@@ -207,6 +246,10 @@ async def chat_websocket(
authorization: str = Header(...), authorization: str = Header(...),
factory: DBFactory = Depends(get_db_factory), factory: DBFactory = Depends(get_db_factory),
): ):
if not server._subscribed:
server._subscribed = True
await server.ChatSubscriber.start_subscribe()
async for session in factory(): async for session in factory():
token = authorization[7:] token = authorization[7:]
if ( if (

View File

@@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp):
async def _participate_room( async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
): ):
participated_user = ( participated_user = (
await session.exec( await session.exec(
@@ -138,6 +138,8 @@ async def _participate_room(
participated_user.joined_at = datetime.now(UTC) participated_user.joined_at = datetime.now(UTC)
db_room.participant_count += 1 db_room.participant_count += 1
await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}")
@router.post( @router.post(
"/rooms", "/rooms",
@@ -150,11 +152,12 @@ async def create_room(
room: APIUploadedRoom, room: APIUploadedRoom,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
): ):
assert current_user.id is not None assert current_user.id is not None
user_id = current_user.id user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id) db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db) await _participate_room(db_room.id, user_id, db_room, db, redis)
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db)) created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
created_room.error = "" created_room.error = ""
return created_room return created_room
@@ -219,11 +222,12 @@ async def add_user_to_room(
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"), user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None: if db_room is not None:
await _participate_room(room_id, user_id, db_room, db) await _participate_room(room_id, user_id, db_room, db, redis)
await db.commit() await db.commit()
await db.refresh(db_room) await db.refresh(db_room)
resp = await RoomResp.from_db(db_room, db) resp = await RoomResp.from_db(db_room, db)
@@ -243,6 +247,7 @@ async def remove_user_from_room(
user_id: int = Path(..., description="用户 ID"), user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
): ):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None: if db_room is not None:
@@ -257,6 +262,7 @@ async def remove_user_from_room(
if participated_user is not None: if participated_user is not None:
participated_user.left_at = datetime.now(UTC) participated_user.left_at = datetime.now(UTC)
db_room.participant_count -= 1 db_room.participant_count -= 1
await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}")
await db.commit() await db.commit()
return None return None
else: else:

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel
from app.database.playlists import Playlist from app.database.playlists import Playlist
from app.database.room import APIUploadedRoom, Room from app.database.room import APIUploadedRoom, Room
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
@@ -25,6 +26,18 @@ async def create_playlist_room_from_api(
session.add(db_room) session.add(db_room)
await session.commit() await session.commit()
await session.refresh(db_room) await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Playlist room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
db_room.channel_id = channel.channel_id
await add_playlists_to_room(session, db_room.id, room.playlist, host_id) await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
await session.refresh(db_room) await session.refresh(db_room)
return db_room return db_room
@@ -57,6 +70,18 @@ async def create_playlist_room(
session.add(db_room) session.add(db_room)
await session.commit() await session.commit()
await session.refresh(db_room) await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Playlist room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
db_room.channel_id = channel.channel_id
await add_playlists_to_room(session, db_room.id, playlist, host_id) await add_playlists_to_room(session, db_room.id, playlist, host_id)
await session.refresh(db_room) await session.refresh(db_room)
return db_room return db_room

View File

@@ -24,6 +24,11 @@ class RedisSubscriber:
del self.handlers[channel] del self.handlers[channel]
await self.pubsub.unsubscribe(channel) await self.pubsub.unsubscribe(channel)
def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]):
if channel not in self.handlers:
self.handlers[channel] = []
self.handlers[channel].append(handler)
async def listen(self): async def listen(self):
while True: while True:
message = await self.pubsub.get_message( message = await self.pubsub.get_message(

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from .base import RedisSubscriber
if TYPE_CHECKING:
from app.router.chat.server import ChatServer
JOIN_CHANNEL = "chat:room:joined"
EXIT_CHANNEL = "chat:room:left"
class ChatSubscriber(RedisSubscriber):
def __init__(self):
super().__init__()
self.room_subscriber: dict[int, list[int]] = {}
self.chat_server: "ChatServer | None" = None
async def start_subscribe(self):
await self.subscribe(JOIN_CHANNEL)
self.add_handler(JOIN_CHANNEL, self.on_join_room)
await self.subscribe(EXIT_CHANNEL)
self.add_handler(EXIT_CHANNEL, self.on_leave_room)
self.start()
async def on_join_room(self, c: str, s: str):
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
async def on_leave_room(self, c: str, s: str):
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))

View File

@@ -6,6 +6,7 @@ from typing import override
from app.database import Room from app.database import Room
from app.database.beatmap import Beatmap from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel
from app.database.lazer_user import User from app.database.lazer_user import User
from app.database.multiplayer_event import MultiplayerEvent from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist from app.database.playlists import Playlist
@@ -195,6 +196,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await session.commit() await session.commit()
await session.refresh(db_room) await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Multiplayer room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue]
db_room.channel_id = channel.channel_id
item = room.playlist[0] item = room.playlist[0]
item.owner_id = client.user_id item.owner_id = client.user_id
room.room_id = db_room.id room.room_id = db_room.id
@@ -280,6 +293,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if db_room is None: if db_room is None:
raise InvokeException("Room does not exist in database") raise InvokeException("Room does not exist in database")
db_room.participant_count += 1 db_room.participant_count += 1
redis = get_redis()
await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}")
return room return room
async def change_beatmap_availability( async def change_beatmap_availability(
@@ -914,6 +931,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if target_store: if target_store:
target_store.room_id = 0 target_store.room_id = 0
redis = get_redis()
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
async def end_room(self, room: ServerMultiplayerRoom): async def end_room(self, room: ServerMultiplayerRoom):
assert room.room.host assert room.room.host
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:

View File

@@ -0,0 +1,54 @@
"""room: add channel_id
Revision ID: df9f725a077c
Revises: dd33d89aa2c2
Create Date: 2025-08-16 08:05:28.748265
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "df9f725a077c"
down_revision: str | Sequence[str] | None = "dd33d89aa2c2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True
)
op.alter_column(
"chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True
)
op.create_index(
op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False
)
op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("rooms", "channel_id")
op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users")
op.alter_column(
"chat_silence_users",
"banned_at",
existing_type=mysql.DATETIME(),
nullable=False,
)
op.alter_column(
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False
)
# ### end Alembic commands ###