feat(chat): support mp/playlist chat

This commit is contained in:
MingxuanGame
2025-08-16 08:42:40 +00:00
parent 368bdfe588
commit 3de73f2420
9 changed files with 206 additions and 10 deletions

View File

@@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel):
auto_skip: bool
auto_start_duration: int
status: RoomStatus
# TODO: channel_id
channel_id: int | None = None
class Room(AsyncAttrs, RoomBase, table=True):
@@ -84,6 +84,7 @@ class RoomResp(RoomBase):
current_playlist_item: PlaylistResp | None = None
current_user_score: PlaylistAggregateScore | None = None
recent_participants: list[UserResp] = Field(default_factory=list)
channel_id: int = 0
@classmethod
async def from_db(
@@ -93,7 +94,9 @@ class RoomResp(RoomBase):
include: list[str] = [],
user: User | None = None,
) -> "RoomResp":
resp = cls.model_validate(room.model_dump())
d = room.model_dump()
d["channel_id"] = d.get("channel_id", 0) or 0
resp = cls.model_validate(d)
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange(
@@ -158,6 +161,7 @@ class RoomResp(RoomBase):
# duration = room.settings.duration,
starts_at=server_room.start_at,
participant_count=len(room.users),
channel_id=server_room.room.channel_id or 0,
)
return resp

View File

@@ -44,6 +44,7 @@ from sqlmodel import col
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.database.room import Room
from app.signalr.hub import MultiplayerHub
HOST_LIMIT = 50
@@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel):
channel_id: int
@classmethod
def from_db(cls, room) -> "MultiplayerRoom":
def from_db(cls, room: "Room") -> "MultiplayerRoom":
"""
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
"""
@@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel):
host_user = MultiplayerRoomUser(user_id=room.host_id)
# playlist 转换
playlist = []
if hasattr(room, "playlist"):
if room.playlist:
for item in room.playlist:
playlist.append(
PlaylistItem(
@@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel):
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=getattr(room, "channel_id", 0),
channel_id=room.channel_id,
)

View File

@@ -4,10 +4,16 @@ import asyncio
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.lazer_user import User
from app.dependencies.database import DBFactory, get_db_factory, get_redis
from app.dependencies.database import (
DBFactory,
engine,
get_db_factory,
get_redis,
)
from app.dependencies.user import get_current_user
from app.log import logger
from app.models.chat import ChatEvent
from app.service.subscribers.chat import ChatSubscriber
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -23,6 +29,9 @@ class ChatServer:
self.redis: Redis = get_redis()
self.tasks: set[asyncio.Task] = set()
self.ChatSubscriber = ChatSubscriber()
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
@@ -158,7 +167,13 @@ class ChatServer:
del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db(
channel, session, user, self.redis, self.channels[channel_id]
channel,
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
)
client = self.connect_client.get(user_id)
if client:
@@ -170,6 +185,30 @@ class ChatServer:
),
)
async def join_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.join_channel(user, channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with AsyncSession(engine) as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.leave_channel(user, channel, session)
server = ChatServer()
@@ -207,6 +246,10 @@ async def chat_websocket(
authorization: str = Header(...),
factory: DBFactory = Depends(get_db_factory),
):
if not server._subscribed:
server._subscribed = True
await server.ChatSubscriber.start_subscribe()
async for session in factory():
token = authorization[7:]
if (

View File

@@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp):
async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
):
participated_user = (
await session.exec(
@@ -138,6 +138,8 @@ async def _participate_room(
participated_user.joined_at = datetime.now(UTC)
db_room.participant_count += 1
await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}")
@router.post(
"/rooms",
@@ -150,11 +152,12 @@ async def create_room(
room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
assert current_user.id is not None
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db)
await _participate_room(db_room.id, user_id, db_room, db, redis)
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
created_room.error = ""
return created_room
@@ -219,11 +222,12 @@ async def add_user_to_room(
room_id: int = Path(..., description="房间 ID"),
user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
await _participate_room(room_id, user_id, db_room, db)
await _participate_room(room_id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
resp = await RoomResp.from_db(db_room, db)
@@ -243,6 +247,7 @@ async def remove_user_from_room(
user_id: int = Path(..., description="用户 ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None:
@@ -257,6 +262,7 @@ async def remove_user_from_room(
if participated_user is not None:
participated_user.left_at = datetime.now(UTC)
db_room.participant_count -= 1
await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}")
await db.commit()
return None
else:

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel
from app.database.playlists import Playlist
from app.database.room import APIUploadedRoom, Room
from app.dependencies.fetcher import get_fetcher
@@ -25,6 +26,18 @@ async def create_playlist_room_from_api(
session.add(db_room)
await session.commit()
await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Playlist room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
db_room.channel_id = channel.channel_id
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
await session.refresh(db_room)
return db_room
@@ -57,6 +70,18 @@ async def create_playlist_room(
session.add(db_room)
await session.commit()
await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Playlist room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
db_room.channel_id = channel.channel_id
await add_playlists_to_room(session, db_room.id, playlist, host_id)
await session.refresh(db_room)
return db_room

View File

@@ -24,6 +24,11 @@ class RedisSubscriber:
del self.handlers[channel]
await self.pubsub.unsubscribe(channel)
def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]):
if channel not in self.handlers:
self.handlers[channel] = []
self.handlers[channel].append(handler)
async def listen(self):
while True:
message = await self.pubsub.get_message(

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from .base import RedisSubscriber
if TYPE_CHECKING:
from app.router.chat.server import ChatServer
JOIN_CHANNEL = "chat:room:joined"
EXIT_CHANNEL = "chat:room:left"
class ChatSubscriber(RedisSubscriber):
def __init__(self):
super().__init__()
self.room_subscriber: dict[int, list[int]] = {}
self.chat_server: "ChatServer | None" = None
async def start_subscribe(self):
await self.subscribe(JOIN_CHANNEL)
self.add_handler(JOIN_CHANNEL, self.on_join_room)
await self.subscribe(EXIT_CHANNEL)
self.add_handler(EXIT_CHANNEL, self.on_leave_room)
self.start()
async def on_join_room(self, c: str, s: str):
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
async def on_leave_room(self, c: str, s: str):
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))

View File

@@ -6,6 +6,7 @@ from typing import override
from app.database import Room
from app.database.beatmap import Beatmap
from app.database.chat import ChannelType, ChatChannel
from app.database.lazer_user import User
from app.database.multiplayer_event import MultiplayerEvent
from app.database.playlists import Playlist
@@ -195,6 +196,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
await session.commit()
await session.refresh(db_room)
channel = ChatChannel(
name=f"room_{db_room.id}",
description="Multiplayer room",
type=ChannelType.MULTIPLAYER,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(db_room)
room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue]
db_room.channel_id = channel.channel_id
item = room.playlist[0]
item.owner_id = client.user_id
room.room_id = db_room.id
@@ -280,6 +293,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if db_room is None:
raise InvokeException("Room does not exist in database")
db_room.participant_count += 1
redis = get_redis()
await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}")
return room
async def change_beatmap_availability(
@@ -914,6 +931,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
if target_store:
target_store.room_id = 0
redis = get_redis()
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
async def end_room(self, room: ServerMultiplayerRoom):
assert room.room.host
async with AsyncSession(engine) as session:

View File

@@ -0,0 +1,54 @@
"""room: add channel_id
Revision ID: df9f725a077c
Revises: dd33d89aa2c2
Create Date: 2025-08-16 08:05:28.748265
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "df9f725a077c"
down_revision: str | Sequence[str] | None = "dd33d89aa2c2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True
)
op.alter_column(
"chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True
)
op.create_index(
op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False
)
op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("rooms", "channel_id")
op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users")
op.alter_column(
"chat_silence_users",
"banned_at",
existing_type=mysql.DATETIME(),
nullable=False,
)
op.alter_column(
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False
)
# ### end Alembic commands ###