feat(chat): support mp/playlist chat
This commit is contained in:
@@ -4,10 +4,16 @@ import asyncio
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import DBFactory, get_db_factory, get_redis
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
engine,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
)
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -23,6 +29,9 @@ class ChatServer:
|
||||
self.redis: Redis = get_redis()
|
||||
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.ChatSubscriber = ChatSubscriber()
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
@@ -158,7 +167,13 @@ class ChatServer:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, user, self.redis, self.channels[channel_id]
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
client = self.connect_client.get(user_id)
|
||||
if client:
|
||||
@@ -170,6 +185,30 @@ class ChatServer:
|
||||
),
|
||||
)
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.join_channel(user, channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
|
||||
|
||||
server = ChatServer()
|
||||
|
||||
@@ -207,6 +246,10 @@ async def chat_websocket(
|
||||
authorization: str = Header(...),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
if not server._subscribed:
|
||||
server._subscribed = True
|
||||
await server.ChatSubscriber.start_subscribe()
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
|
||||
@@ -116,7 +116,7 @@ class APICreatedRoom(RoomResp):
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
|
||||
):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
@@ -138,6 +138,8 @@ async def _participate_room(
|
||||
participated_user.joined_at = datetime.now(UTC)
|
||||
db_room.participant_count += 1
|
||||
|
||||
await redis.publish("chat:room:joined", f"{db_room.channel_id}:{user_id}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rooms",
|
||||
@@ -150,11 +152,12 @@ async def create_room(
|
||||
room: APIUploadedRoom,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
|
||||
created_room.error = ""
|
||||
return created_room
|
||||
@@ -219,11 +222,12 @@ async def add_user_to_room(
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
await _participate_room(room_id, user_id, db_room, db)
|
||||
await _participate_room(room_id, user_id, db_room, db, redis)
|
||||
await db.commit()
|
||||
await db.refresh(db_room)
|
||||
resp = await RoomResp.from_db(db_room, db)
|
||||
@@ -243,6 +247,7 @@ async def remove_user_from_room(
|
||||
user_id: int = Path(..., description="用户 ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is not None:
|
||||
@@ -257,6 +262,7 @@ async def remove_user_from_room(
|
||||
if participated_user is not None:
|
||||
participated_user.left_at = datetime.now(UTC)
|
||||
db_room.participant_count -= 1
|
||||
await redis.publish("chat:room:left", f"{db_room.channel_id}:{user_id}")
|
||||
await db.commit()
|
||||
return None
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user