feat(chat): support mp/playlist chat
This commit is contained in:
@@ -54,7 +54,7 @@ class RoomBase(SQLModel, UTCBaseModel):
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
status: RoomStatus
|
||||
# TODO: channel_id
|
||||
channel_id: int | None = None
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
@@ -84,6 +84,7 @@ class RoomResp(RoomBase):
|
||||
current_playlist_item: PlaylistResp | None = None
|
||||
current_user_score: PlaylistAggregateScore | None = None
|
||||
recent_participants: list[UserResp] = Field(default_factory=list)
|
||||
channel_id: int = 0
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
@@ -93,7 +94,9 @@ class RoomResp(RoomBase):
|
||||
include: list[str] = [],
|
||||
user: User | None = None,
|
||||
) -> "RoomResp":
|
||||
resp = cls.model_validate(room.model_dump())
|
||||
d = room.model_dump()
|
||||
d["channel_id"] = d.get("channel_id", 0) or 0
|
||||
resp = cls.model_validate(d)
|
||||
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
||||
difficulty_range = RoomDifficultyRange(
|
||||
@@ -158,6 +161,7 @@ class RoomResp(RoomBase):
|
||||
# duration = room.settings.duration,
|
||||
starts_at=server_room.start_at,
|
||||
participant_count=len(room.users),
|
||||
channel_id=server_room.room.channel_id or 0,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ from sqlmodel import col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.room import Room
|
||||
from app.signalr.hub import MultiplayerHub
|
||||
|
||||
HOST_LIMIT = 50
|
||||
@@ -348,7 +349,7 @@ class MultiplayerRoom(BaseModel):
|
||||
channel_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, room) -> "MultiplayerRoom":
|
||||
def from_db(cls, room: "Room") -> "MultiplayerRoom":
|
||||
"""
|
||||
将 Room (数据库模型) 转换为 MultiplayerRoom (业务模型)
|
||||
"""
|
||||
@@ -358,7 +359,7 @@ class MultiplayerRoom(BaseModel):
|
||||
host_user = MultiplayerRoomUser(user_id=room.host_id)
|
||||
# playlist 转换
|
||||
playlist = []
|
||||
if hasattr(room, "playlist"):
|
||||
if room.playlist:
|
||||
for item in room.playlist:
|
||||
playlist.append(
|
||||
PlaylistItem(
|
||||
@@ -396,7 +397,7 @@ class MultiplayerRoom(BaseModel):
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=getattr(room, "channel_id", 0),
|
||||
channel_id=room.channel_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,10 +4,16 @@ import asyncio
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import DBFactory, get_db_factory, get_redis
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
engine,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
)
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -23,6 +29,9 @@ class ChatServer:
|
||||
self.redis: Redis = get_redis()
|
||||
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.ChatSubscriber = ChatSubscriber()
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
@@ -158,7 +167,13 @@ class ChatServer:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, user, self.redis, self.channels[channel_id]
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
@@ -170,6 +185,30 @@ class ChatServer:
|
||||
),
|
||||
)
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.join_channel(user, channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
|
||||
server = ChatServer()
|
||||
|
||||
@@ -207,6 +246,10 @@ async def chat_websocket(
|
||||
authorization: str = Header(...),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
if not server._subscribed:
|
||||
server._subscribed = True
|
||||
await server.ChatSubscriber.start_subscribe()
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
|
||||
@@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp):
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
|
||||
):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
@@ -138,6 +138,8 @@ async def _participate_room(
|
||||
participated_user.joined_at = datetime.now(UTC)
|
||||
db_room.participant_count += 1
|
||||
|
||||
await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rooms",
|
||||
@@ -150,11 +152,12 @@ async def create_room(
|
||||
room: APIUploadedRoom,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
||||
created_room.error = ""
|
||||
return created_room
|
||||
@@ -219,11 +222,12 @@ async def add_user_to_room(
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
await _participate_room(room_id, user_id, db_room, db)
|
||||
await _participate_room(room_id, user_id, db_room, db, redis)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
resp = await RoomResp.from_db(db_room, db)
|
||||
@@ -243,6 +247,7 @@ async def remove_user_from_room(
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -257,6 +262,7 @@ async def remove_user_from_room(
|
||||
if participated_user is not None:
|
||||
participated_user.left_at = datetime.now(UTC)
|
||||
db_room.participant_count -= 1
|
||||
await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}")
|
||||
await db.commit()
|
||||
return None
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
from app.database.playlists import Playlist
|
||||
from app.database.room import APIUploadedRoom, Room
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
@@ -25,6 +26,18 @@ async def create_playlist_room_from_api(
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
|
||||
channel = ChatChannel(
|
||||
name=f"room_{db_room.id}",
|
||||
description="Playlist room",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(db_room)
|
||||
db_room.channel_id = channel.channel_id
|
||||
|
||||
await add_playlists_to_room(session, db_room.id, room.playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
@@ -57,6 +70,18 @@ async def create_playlist_room(
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
|
||||
channel = ChatChannel(
|
||||
name=f"room_{db_room.id}",
|
||||
description="Playlist room",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(db_room)
|
||||
db_room.channel_id = channel.channel_id
|
||||
|
||||
await add_playlists_to_room(session, db_room.id, playlist, host_id)
|
||||
await session.refresh(db_room)
|
||||
return db_room
|
||||
|
||||
@@ -24,6 +24,11 @@ class RedisSubscriber:
|
||||
del self.handlers[channel]
|
||||
await self.pubsub.unsubscribe(channel)
|
||||
|
||||
def add_handler(self, channel: str, handler: Callable[[str, str], Awaitable[Any]]):
|
||||
if channel not in self.handlers:
|
||||
self.handlers[channel] = []
|
||||
self.handlers[channel].append(handler)
|
||||
|
||||
async def listen(self):
|
||||
while True:
|
||||
message = await self.pubsub.get_message(
|
||||
|
||||
38
app/service/subscribers/chat.py
Normal file
38
app/service/subscribers/chat.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import RedisSubscriber
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.router.chat.server import ChatServer
|
||||
|
||||
|
||||
JOIN_CHANNEL = "chat:room:joined"
|
||||
EXIT_CHANNEL = "chat:room:left"
|
||||
|
||||
|
||||
class ChatSubscriber(RedisSubscriber):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.room_subscriber: dict[int, list[int]] = {}
|
||||
self.chat_server: "ChatServer | None" = None
|
||||
|
||||
async def start_subscribe(self):
|
||||
await self.subscribe(JOIN_CHANNEL)
|
||||
self.add_handler(JOIN_CHANNEL, self.on_join_room)
|
||||
await self.subscribe(EXIT_CHANNEL)
|
||||
self.add_handler(EXIT_CHANNEL, self.on_leave_room)
|
||||
self.start()
|
||||
|
||||
async def on_join_room(self, c: str, s: str):
|
||||
channel_id, user_id = s.split(":")
|
||||
if self.chat_server is None:
|
||||
return
|
||||
await self.chat_server.join_room_channel(int(channel_id), int(user_id))
|
||||
|
||||
async def on_leave_room(self, c: str, s: str):
|
||||
channel_id, user_id = s.split(":")
|
||||
if self.chat_server is None:
|
||||
return
|
||||
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))
|
||||
@@ -6,6 +6,7 @@ from typing import override
|
||||
|
||||
from app.database import Room
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
from app.database.lazer_user import User
|
||||
from app.database.multiplayer_event import MultiplayerEvent
|
||||
from app.database.playlists import Playlist
|
||||
@@ -195,6 +196,18 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
|
||||
channel = ChatChannel(
|
||||
name=f"room_{db_room.id}",
|
||||
description="Multiplayer room",
|
||||
type=ChannelType.MULTIPLAYER,
|
||||
)
|
||||
session.add(channel)
|
||||
await session.commit()
|
||||
await session.refresh(channel)
|
||||
await session.refresh(db_room)
|
||||
room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue]
|
||||
db_room.channel_id = channel.channel_id
|
||||
|
||||
item = room.playlist[0]
|
||||
item.owner_id = client.user_id
|
||||
room.room_id = db_room.id
|
||||
@@ -280,6 +293,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
if db_room is None:
|
||||
raise InvokeException("Room does not exist in database")
|
||||
db_room.participant_count += 1
|
||||
|
||||
redis = get_redis()
|
||||
await redis.publish("chat:room:joined", f"{room.channel_id}:{user.user_id}")
|
||||
|
||||
return room
|
||||
|
||||
async def change_beatmap_availability(
|
||||
@@ -914,6 +931,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]):
|
||||
if target_store:
|
||||
target_store.room_id = 0
|
||||
|
||||
redis = get_redis()
|
||||
await redis.publish("chat:room:left", f"{room.room.channel_id}:{user.user_id}")
|
||||
|
||||
async def end_room(self, room: ServerMultiplayerRoom):
|
||||
assert room.room.host
|
||||
async with AsyncSession(engine) as session:
|
||||
|
||||
54
migrations/versions/df9f725a077c_room_add_channel_id.py
Normal file
54
migrations/versions/df9f725a077c_room_add_channel_id.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""room: add channel_id
|
||||
|
||||
Revision ID: df9f725a077c
|
||||
Revises: dd33d89aa2c2
|
||||
Create Date: 2025-08-16 08:05:28.748265
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "df9f725a077c"
|
||||
down_revision: str | Sequence[str] | None = "dd33d89aa2c2"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False
|
||||
)
|
||||
op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("rooms", "channel_id")
|
||||
op.drop_index(op.f("ix_chat_silence_users_id"), table_name="chat_silence_users")
|
||||
op.alter_column(
|
||||
"chat_silence_users",
|
||||
"banned_at",
|
||||
existing_type=mysql.DATETIME(),
|
||||
nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user