from __future__ import annotations import asyncio from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.lazer_user import User from app.dependencies.database import ( DBFactory, engine, get_db_factory, get_redis, ) from app.dependencies.user import get_current_user from app.log import logger from app.models.chat import ChatEvent from app.service.subscribers.chat import ChatSubscriber from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes from fastapi.websockets import WebSocketState from redis.asyncio import Redis from sqlmodel.ext.asyncio.session import AsyncSession class ChatServer: def __init__(self): self.connect_client: dict[int, WebSocket] = {} self.channels: dict[int, list[int]] = {} self.redis: Redis = get_redis() self.tasks: set[asyncio.Task] = set() self.ChatSubscriber = ChatSubscriber() self.ChatSubscriber.chat_server = self self._subscribed = False def _add_task(self, task): task = asyncio.create_task(task) self.tasks.add(task) task.add_done_callback(self.tasks.discard) def connect(self, user_id: int, client: WebSocket): self.connect_client[user_id] = client def get_user_joined_channel(self, user_id: int) -> list[int]: return [ channel_id for channel_id, users in self.channels.items() if user_id in users ] async def disconnect(self, user: User, session: AsyncSession): user_id = user.id if user_id in self.connect_client: del self.connect_client[user_id] for channel_id, channel in self.channels.items(): if user_id in channel: channel.remove(user_id) channel = await ChatChannel.get(channel_id, session) if channel: await self.leave_channel(user, channel, session) async def send_event(self, client: WebSocket, event: ChatEvent): if client.client_state == WebSocketState.CONNECTED: await client.send_text(event.model_dump_json()) async def broadcast(self, channel_id: int, event: ChatEvent): for user_id in self.channels.get(channel_id, []): client = self.connect_client.get(user_id) if client: await self.send_event(client, event) async def mark_as_read(self, channel_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_msg", message_id) async def send_message_to_channel( self, message: ChatMessageResp, is_bot_command: bool = False ): event = ChatEvent( event="chat.message.new", data={"messages": [message], "users": [message.sender]}, ) if is_bot_command: client = self.connect_client.get(message.sender_id) if client: self._add_task(self.send_event(client, event)) else: self._add_task( self.broadcast( message.channel_id, event, ) ) assert message.message_id await self.mark_as_read(message.channel_id, message.message_id) async def batch_join_channel( self, users: list[User], channel: ChatChannel, session: AsyncSession ): channel_id = channel.channel_id assert channel_id is not None not_joined = [] if channel_id not in self.channels: self.channels[channel_id] = [] for user in users: assert user.id is not None if user.id not in self.channels[channel_id]: self.channels[channel_id].append(user.id) not_joined.append(user) for user in not_joined: assert user.id is not None channel_resp = await ChatChannelResp.from_db( channel, session, user, self.redis, self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, ) client = self.connect_client.get(user.id) if client: await self.send_event( client, ChatEvent( event="chat.channel.join", data=channel_resp.model_dump(), ), ) async def join_channel( self, user: User, channel: ChatChannel, session: AsyncSession ) -> ChatChannelResp: user_id = user.id channel_id = channel.channel_id assert channel_id is not None assert user_id is not None if channel_id not in self.channels: self.channels[channel_id] = [] if user_id not in self.channels[channel_id]: self.channels[channel_id].append(user_id) channel_resp = await ChatChannelResp.from_db( channel, session, user, self.redis, self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, ) client = self.connect_client.get(user_id) if client: await self.send_event( client, ChatEvent( event="chat.channel.join", data=channel_resp.model_dump(), ), ) return channel_resp async def leave_channel( self, user: User, channel: ChatChannel, session: AsyncSession ) -> None: user_id = user.id channel_id = channel.channel_id assert channel_id is not None assert user_id is not None if channel_id in self.channels and user_id in self.channels[channel_id]: self.channels[channel_id].remove(user_id) if (c := self.channels.get(channel_id)) is not None and not c: del self.channels[channel_id] channel_resp = await ChatChannelResp.from_db( channel, session, user, self.redis, self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None, ) client = self.connect_client.get(user_id) if client: await self.send_event( client, ChatEvent( event="chat.channel.part", data=channel_resp.model_dump(), ), ) async def join_room_channel(self, channel_id: int, user_id: int): async with AsyncSession(engine) as session: channel = await ChatChannel.get(channel_id, session) if channel is None: return user = await session.get(User, user_id) if user is None: return await self.join_channel(user, channel, session) async def leave_room_channel(self, channel_id: int, user_id: int): async with AsyncSession(engine) as session: channel = await ChatChannel.get(channel_id, session) if channel is None: return user = await session.get(User, user_id) if user is None: return await self.leave_channel(user, channel, session) server = ChatServer() chat_router = APIRouter(include_in_schema=False) async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): try: while True: packets = await ws.receive_json() if packets.get("event") == "chat.end": async for session in factory(): user = await session.get(User, user_id) if user is None: break await server.disconnect(user, session) await ws.close(code=1000) break except WebSocketDisconnect as e: logger.info( f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}" ) except RuntimeError as e: if "disconnect message" in str(e): logger.info(f"[NotificationServer] Client {user_id} closed the connection.") else: logger.exception(f"RuntimeError in client {user_id}: {e}") except Exception: logger.exception(f"Error in client {user_id}") @chat_router.websocket("/notification-server") async def chat_websocket( websocket: WebSocket, authorization: str = Header(...), factory: DBFactory = Depends(get_db_factory), ): if not server._subscribed: server._subscribed = True await server.ChatSubscriber.start_subscribe() async for session in factory(): token = authorization[7:] if ( user := await get_current_user( SecurityScopes(scopes=["chat.read"]), session, token_pw=token ) ) is None: await websocket.close(code=1008) return await websocket.accept() login = await websocket.receive_json() if login.get("event") != "chat.start": await websocket.close(code=1008) return user_id = user.id assert user_id server.connect(user_id, websocket) channel = await ChatChannel.get(1, session) if channel is not None: await server.join_channel(user, channel, session) await _listen_stop(websocket, user_id, factory)