"""Chat endpoints registered on the v2 API router.

Covers silence polling (``/chat/ack``), sending and listing channel messages,
marking messages as read, and opening private-message channels (``/chat/new``).
"""

from __future__ import annotations

from app.database import ChatMessageResp
from app.database.chat import (
    ChannelType,
    ChatChannel,
    ChatChannelResp,
    ChatMessage,
    MessageType,
    SilenceUser,
    UserSilenceResp,
)
from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router

from .banchobot import bot
from .server import server

from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession


class KeepAliveResp(BaseModel):
    """Response body for ``POST /chat/ack``."""

    # Silences matched by the keep-alive query; stays empty when nothing
    # matches or when neither ``history_since`` nor ``since`` is given.
    silences: list[UserSilenceResp] = Field(default_factory=list)


@router.post("/chat/ack")
async def keep_alive(
    history_since: int | None = Query(None),
    since: int | None = Query(None),
    current_user: User = Security(get_current_user, scopes=["chat.read"]),
    session: AsyncSession = Depends(get_db),
):
    """Return user silences the client has not seen yet.

    Args:
        history_since: If given (and non-zero), return every silence whose id
            is greater than this value. Takes precedence over ``since``.
        since: Used only when ``history_since`` is not given. If non-zero, the
            chat message with this id is looked up and silences whose
            ``banned_at`` is later than that message's timestamp are returned.
            An unknown message id yields no silences.
        current_user: Authenticated user; resolved to enforce the
            ``chat.read`` scope, not otherwise used here.
        session: Database session.

    Returns:
        A ``KeepAliveResp``, possibly with an empty ``silences`` list.
    """
    resp = KeepAliveResp()
    # NOTE: these are truthiness checks, so a value of 0 behaves exactly like
    # "parameter not supplied".
    if history_since:
        silences = (
            await session.exec(
                select(SilenceUser).where(col(SilenceUser.id) > history_since)
            )
        ).all()
        resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
    elif since:
        msg = await session.get(ChatMessage, since)
        if msg:
            silences = (
                await session.exec(
                    select(SilenceUser).where(
                        col(SilenceUser.banned_at) > msg.timestamp
                    )
                )
            ).all()
            resp.silences.extend(
                [UserSilenceResp.from_db(silence) for silence in silences]
            )

    return resp


class MessageReq(BaseModel):
    """Request body (JSON or form) for sending a message to a channel."""

    message: str
    # When true the message is stored as MessageType.ACTION instead of PLAIN.
    is_action: bool = False
    # Optional client-supplied identifier, stored on the message unchanged.
    uuid: str | None = None


@router.post("/chat/channels/{channel}/messages", response_model=ChatMessageResp)
async def send_message(
    channel: str,
    req: MessageReq = Depends(BodyOrForm(MessageReq)),
    current_user: User = Security(get_current_user, scopes=["chat.write"]),
    session: AsyncSession = Depends(get_db),
):
    """Persist a message in ``channel`` and broadcast it.

    Messages starting with ``!`` are additionally passed to the bancho bot.

    Raises:
        HTTPException: 404 if ``ChatChannel.get`` finds no such channel.

    Returns:
        The stored message as a ``ChatMessageResp``.
    """
    db_channel = await ChatChannel.get(channel, session)
    if db_channel is None:
        raise HTTPException(status_code=404, detail="Channel not found")

    # Narrow Optional ids for the type checker; both rows are persisted here.
    assert db_channel.channel_id
    assert current_user.id
    msg = ChatMessage(
        channel_id=db_channel.channel_id,
        content=req.message,
        sender_id=current_user.id,
        type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
        uuid=req.uuid,
    )
    session.add(msg)
    await session.commit()
    # Reload objects after the commit before serialising them.
    await session.refresh(msg)
    await session.refresh(current_user)
    await session.refresh(db_channel)
    resp = await ChatMessageResp.from_db(msg, session, current_user)

    is_bot_command = req.message.startswith("!")
    # NOTE(review): the second argument is true only for bot commands in
    # public channels; its exact effect lives in server.send_message_to_channel.
    await server.send_message_to_channel(
        resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
    )
    if is_bot_command:
        await bot.try_handle(current_user, db_channel, req.message, session)
    return resp


@router.get("/chat/channels/{channel}/messages", response_model=list[ChatMessageResp])
async def get_message(
    channel: str,
    limit: int = Query(50, ge=1, le=50),
    since: int = Query(default=0, ge=0),
    until: int | None = Query(None),
    current_user: User = Security(get_current_user, scopes=["chat.read"]),
    session: AsyncSession = Depends(get_db),
):
    """List up to ``limit`` messages of ``channel``, oldest first.

    Args:
        channel: Channel identifier as accepted by ``ChatChannel.get``.
        limit: Maximum number of messages returned (1-50).
        since: Only messages with ``message_id`` greater than this.
        until: If given, only messages with ``message_id`` less than this.

    Raises:
        HTTPException: 404 if the channel does not exist.

    Returns:
        The newest matching messages (by ``message_id``), in ascending order.
    """
    db_channel = await ChatChannel.get(channel, session)
    if db_channel is None:
        raise HTTPException(status_code=404, detail="Channel not found")

    conditions = [
        ChatMessage.channel_id == db_channel.channel_id,
        col(ChatMessage.message_id) > since,
    ]
    if until is not None:
        conditions.append(col(ChatMessage.message_id) < until)

    messages = await session.exec(
        select(ChatMessage)
        .where(*conditions)
        # Order by the same key the since/until cursors use. Ordering by
        # timestamp let LIMIT pick an arbitrary subset of rows with equal
        # timestamps, so paging with `until` could skip messages.
        .order_by(col(ChatMessage.message_id).desc())
        .limit(limit)
    )
    resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
    # The query takes the newest rows first; return them oldest-first.
    resp.reverse()
    return resp


@router.put("/chat/channels/{channel}/mark-as-read/{message}", status_code=204)
async def mark_as_read(
    channel: str,
    message: int,
    current_user: User = Security(get_current_user, scopes=["chat.read"]),
    session: AsyncSession = Depends(get_db),
):
    """Forward a mark-as-read for ``message`` in ``channel`` to the chat server.

    Raises:
        HTTPException: 404 if the channel does not exist.
    """
    db_channel = await ChatChannel.get(channel, session)
    if db_channel is None:
        raise HTTPException(status_code=404, detail="Channel not found")
    assert db_channel.channel_id
    await server.mark_as_read(db_channel.channel_id, message)


class PMReq(BaseModel):
    """Request body (JSON or form) for ``POST /chat/new``."""

    # Id of the user to open the private conversation with.
    target_id: int
    # First message sent in the conversation.
    message: str
    # When true the message is stored as MessageType.ACTION instead of PLAIN.
    is_action: bool = False
    # Optional client-supplied identifier, stored on the message unchanged.
    uuid: str | None = None


class NewPMResp(BaseModel):
    """Response body for ``POST /chat/new``."""

    channel: ChatChannelResp
    message: ChatMessageResp
    # Copy of ``channel.channel_id``.
    new_channel_id: int


@router.post("/chat/new")
async def create_new_pm(
    req: PMReq = Depends(BodyOrForm(PMReq)),
    current_user: User = Security(get_current_user, scopes=["chat.write"]),
    session: AsyncSession = Depends(get_db),
    redis: Redis = Depends(get_redis),
):
    """Open (or reuse) a PM channel with ``req.target_id`` and send a message.

    Raises:
        HTTPException: 404 if the target user does not exist; 403 with the
            reason from ``User.is_user_can_pm`` if messaging is not allowed.

    Returns:
        A ``NewPMResp`` with the channel, the sent message and the channel id.
    """
    user_id = current_user.id
    target = await session.get(User, req.target_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Target user not found")
    is_can_pm, block = await target.is_user_can_pm(current_user, session)
    if not is_can_pm:
        raise HTTPException(status_code=403, detail=block)

    assert user_id
    channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
    if channel is None:
        # No existing PM channel between these two users: create one.
        channel = ChatChannel(
            name=f"pm_{user_id}_{req.target_id}",
            description="Private message channel",
            type=ChannelType.PM,
        )
        session.add(channel)
        await session.commit()
        await session.refresh(channel)
        await session.refresh(target)
        await session.refresh(current_user)

    assert channel.channel_id
    await server.batch_join_channel([target, current_user], channel, session)
    channel_resp = await ChatChannelResp.from_db(
        channel, session, current_user, redis, server.channels[channel.channel_id]
    )
    msg = ChatMessage(
        channel_id=channel.channel_id,
        content=req.message,
        sender_id=user_id,
        type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
        uuid=req.uuid,
    )
    session.add(msg)
    await session.commit()
    await session.refresh(msg)
    await session.refresh(current_user)
    await session.refresh(channel)
    message_resp = await ChatMessageResp.from_db(msg, session, current_user)
    await server.send_message_to_channel(message_resp)
    return NewPMResp(
        channel=channel_resp,
        message=message_resp,
        new_channel_id=channel_resp.channel_id,
    )