From 368bdfe58811bce30ebd67afb56a4fd74720c2a2 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 16 Aug 2025 07:48:19 +0000 Subject: [PATCH] feat(chat): support pm --- app/database/chat.py | 43 ++++++++++++-- app/database/lazer_user.py | 40 +++++++++++++ app/router/chat/channel.py | 117 +++++++++++++++++++++++++++++++++++-- app/router/chat/message.py | 83 +++++++++++++++++++++++++- app/router/chat/server.py | 49 ++++++++++++++-- 5 files changed, 316 insertions(+), 16 deletions(-) diff --git a/app/database/chat.py b/app/database/chat.py index 565657d..777e2ac 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -16,6 +16,7 @@ from sqlmodel import ( ForeignKey, Relationship, SQLModel, + col, select, ) from sqlmodel.ext.asyncio.session import AsyncSession @@ -65,6 +66,15 @@ class ChatChannel(ChatChannelBase, table=True): await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) ).first() + @classmethod + async def get_pm_channel( + cls, user1: int, user2: int, session: AsyncSession + ) -> "ChatChannel | None": + channel = await cls.get(f"pm_{user1}_{user2}", session) + if channel is None: + channel = await cls.get(f"pm_{user2}_{user1}", session) + return channel + class ChatChannelResp(ChatChannelBase): channel_id: int @@ -73,8 +83,8 @@ class ChatChannelResp(ChatChannelBase): current_user_attributes: ChatUserAttributes | None = None last_read_id: int | None = None last_message_id: int | None = None - recent_messages: list[str] | None = None - users: list[int] | None = None + recent_messages: list["ChatMessageResp"] = Field(default_factory=list) + users: list[int] = Field(default_factory=list) message_length_limit: int = 1000 @classmethod @@ -82,9 +92,10 @@ class ChatChannelResp(ChatChannelBase): cls, channel: ChatChannel, session: AsyncSession, - users: list[int], user: User, redis: Redis, + users: list[int] | None = None, + include_recent_messages: bool = False, ) -> Self: c = cls.model_validate(channel) silence = ( @@ -123,9 +134,33 @@ class ChatChannelResp(ChatChannelBase): c.moderated = False c.current_user_attributes = attribute - c.users = users + if c.type != ChannelType.PUBLIC and users is not None: + c.users = users c.last_message_id = last_msg c.last_read_id = last_read_id + + if include_recent_messages: + messages = ( + await session.exec( + select(ChatMessage) + .where(ChatMessage.channel_id == channel.channel_id) + .order_by(col(ChatMessage.timestamp).desc()) + .limit(10) + ) + ).all() + c.recent_messages = [ + await ChatMessageResp.from_db(msg, session, user) for msg in messages + ] + c.recent_messages.reverse() + + if c.type == ChannelType.PM and users and len(users) == 2: + target_user_id = next(u for u in users if u != user.id) + target_name = await session.exec( + select(User.username).where(User.id == target_user_id) + ) + c.name = target_name.one() + assert user.id + c.users = [target_user_id, user.id] return c diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index ca767f5..e242fa4 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -168,6 +168,46 @@ class User(AsyncAttrs, UserBase, table=True): default=None, sa_column=Column(DateTime(timezone=True)), exclude=True ) + async def is_user_can_pm( + self, from_user: "User", session: AsyncSession + ) -> tuple[bool, str]: + from .relationship import Relationship, RelationshipType + + from_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == from_user.id, + Relationship.target_id == self.id, + ) + ) + ).first() + if from_relationship and from_relationship.type == RelationshipType.BLOCK: + return False, "You have blocked the target user." + if from_user.pm_friends_only and ( + not from_relationship or from_relationship.type != RelationshipType.FOLLOW + ): + return ( + False, + "You have disabled non-friend communications " + "and target user is not your friend.", + ) + + relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == self.id, + Relationship.target_id == from_user.id, + ) + ) + ).first() + if relationship and relationship.type == RelationshipType.BLOCK: + return False, "Target user has blocked you." + if self.pm_friends_only and ( + not relationship or relationship.type != RelationshipType.FOLLOW + ): + return False, "Target user has disabled non-friend communications" + return True, "" + class UserResp(UserBase): id: int | None = None diff --git a/app/router/chat/channel.py b/app/router/chat/channel.py index 798a5f2..ceb7b66 100644 --- a/app/router/chat/channel.py +++ b/app/router/chat/channel.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal, Self from app.database.chat import ( ChannelType, @@ -9,15 +9,16 @@ from app.database.chat import ( ) from app.database.lazer_user import User, UserResp from app.dependencies.database import get_db, get_redis +from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router from .server import server from fastapi import Depends, HTTPException, Query, Security -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from redis.asyncio import Redis -from sqlmodel import select +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -37,6 +38,7 @@ async def get_update( ): resp = UpdateResponse() if "presence" in includes: + assert current_user.id channel_ids = server.get_user_joined_channel(current_user.id) for channel_id in channel_ids: channel = await ChatChannel.get(channel_id, session) @@ -45,9 +47,11 @@ async def get_update( await ChatChannelResp.from_db( channel, session, - server.channels.get(channel_id, []), current_user, redis, + server.channels.get(channel_id, []) + if channel.type != ChannelType.PUBLIC + else None, ) ) return resp @@ -103,9 +107,11 @@ async def get_channel_list( await ChatChannelResp.from_db( channel, session, - server.channels.get(channel.channel_id, []), current_user, redis, + server.channels.get(channel.channel_id, []) + if channel.type != ChannelType.PUBLIC + else None, ) ) return results @@ -127,12 +133,111 @@ async def get_channel( if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") assert db_channel.channel_id is not None + + users = [] + if db_channel.type == ChannelType.PM: + user_ids = db_channel.name.split("_")[1:] + if len(user_ids) != 2: + raise HTTPException(status_code=404, detail="Target user not found") + for id_ in user_ids: + if int(id_) == current_user.id: + continue + target_user = await session.get(User, int(id_)) + if target_user is None: + raise HTTPException(status_code=404, detail="Target user not found") + users.extend([target_user, current_user]) + break + return GetChannelResp( channel=await ChatChannelResp.from_db( db_channel, session, - server.channels.get(db_channel.channel_id, []), current_user, redis, + server.channels.get(db_channel.channel_id, []) + if db_channel.type != ChannelType.PUBLIC + else None, ) ) + + +class CreateChannelReq(BaseModel): + class AnnounceChannel(BaseModel): + name: str + description: str + + message: str | None = None + type: Literal["ANNOUNCE", "PM"] = "PM" + target_id: int | None = None + target_ids: list[int] | None = None + channel: AnnounceChannel | None = None + + @model_validator(mode="after") + def check(self) -> Self: + if self.type == "PM": + if self.target_id is None: + raise ValueError("target_id must be set for PM channels") + else: + if self.target_ids is None or self.channel is None or self.message is None: + raise ValueError( + "target_ids, channel, and message must be set for ANNOUNCE channels" + ) + return self + + +@router.post("/chat/channels") +async def create_channel( + req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)), + current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + if req.type == "PM": + target = await session.get(User, req.target_id) + if not target: + raise HTTPException(status_code=404, detail="Target user not found") + is_can_pm, block = await target.is_user_can_pm(current_user, session) + if not is_can_pm: + raise HTTPException(status_code=403, detail=block) + + channel = await ChatChannel.get_pm_channel( + current_user.id, # pyright: ignore[reportArgumentType] + req.target_id, # pyright: ignore[reportArgumentType] + session, + ) + channel_name = f"pm_{current_user.id}_{req.target_id}" + else: + channel_name = req.channel.name if req.channel else "Unnamed Channel" + channel = await ChatChannel.get(channel_name, session) + + if channel is None: + channel = ChatChannel( + name=channel_name, + description=req.channel.description + if req.channel + else "Private message channel", + type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(current_user) + if req.type == "PM": + await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable] + await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable] + else: + target_users = await session.exec( + select(User).where(col(User.id).in_(req.target_ids or [])) + ) + await server.batch_join_channel([*target_users, current_user], channel, session) + + await server.join_channel(current_user, channel, session) + assert channel.channel_id + return await ChatChannelResp.from_db( + channel, + session, + current_user, + redis, + server.channels.get(channel.channel_id, []), + include_recent_messages=True, + ) diff --git a/app/router/chat/message.py b/app/router/chat/message.py index 5bab869..dc4f134 100644 --- a/app/router/chat/message.py +++ b/app/router/chat/message.py @@ -1,9 +1,15 @@ from __future__ import annotations from app.database import ChatMessageResp -from app.database.chat import ChatChannel, ChatMessage, MessageType +from app.database.chat import ( + ChannelType, + ChatChannel, + ChatChannelResp, + ChatMessage, + MessageType, +) from app.database.lazer_user import User -from app.dependencies.database import get_db +from app.dependencies.database import get_db, get_redis from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.router.v2 import api_v2_router as router @@ -12,6 +18,7 @@ from .server import server from fastapi import Depends, HTTPException, Query, Security from pydantic import BaseModel +from redis.asyncio import Redis from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -42,6 +49,9 @@ async def send_message( db_channel = await ChatChannel.get(channel, session) if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") + + assert db_channel.channel_id + assert current_user.id msg = ChatMessage( channel_id=db_channel.channel_id, content=req.message, @@ -95,4 +105,73 @@ async def mark_as_read( db_channel = await ChatChannel.get(channel, session) if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") + assert db_channel.channel_id await server.mark_as_read(db_channel.channel_id, message) + + +class PMReq(BaseModel): + target_id: int + message: str + is_action: bool = False + uuid: str | None = None + + +class NewPMResp(BaseModel): + channel: ChatChannelResp + message: ChatMessageResp + new_channel_id: int + + +@router.post("/chat/new") +async def create_new_pm( + req: PMReq = Depends(BodyOrForm(PMReq)), + current_user: User = Security(get_current_user, scopes=["chat.write"]), + session: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), +): + user_id = current_user.id + target = await session.get(User, req.target_id) + if target is None: + raise HTTPException(status_code=404, detail="Target user not found") + is_can_pm, block = await target.is_user_can_pm(current_user, session) + if not is_can_pm: + raise HTTPException(status_code=403, detail=block) + + assert user_id + channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session) + if channel is None: + channel = ChatChannel( + name=f"pm_{user_id}_{req.target_id}", + description="Private message channel", + type=ChannelType.PM, + ) + session.add(channel) + await session.commit() + await session.refresh(channel) + await session.refresh(target) + await session.refresh(current_user) + + assert channel.channel_id + await server.batch_join_channel([target, current_user], channel, session) + channel_resp = await ChatChannelResp.from_db( + channel, session, current_user, redis, server.channels[channel.channel_id] + ) + msg = ChatMessage( + channel_id=channel.channel_id, + content=req.message, + sender_id=user_id, + type=MessageType.ACTION if req.is_action else MessageType.PLAIN, + uuid=req.uuid, + ) + session.add(msg) + await session.commit() + await session.refresh(msg) + await session.refresh(current_user) + await session.refresh(channel) + message_resp = await ChatMessageResp.from_db(msg, session, current_user) + await server.send_message_to_channel(message_resp) + return NewPMResp( + channel=channel_resp, + message=message_resp, + new_channel_id=channel_resp.channel_id, + ) diff --git a/app/router/chat/server.py b/app/router/chat/server.py index 727f426..2c91952 100644 --- a/app/router/chat/server.py +++ b/app/router/chat/server.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp +from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp from app.database.lazer_user import User from app.dependencies.database import DBFactory, get_db_factory, get_redis from app.dependencies.user import get_current_user @@ -73,14 +73,50 @@ class ChatServer: ), ) ) + assert message.message_id await self.mark_as_read(message.channel_id, message.message_id) + async def batch_join_channel( + self, users: list[User], channel: ChatChannel, session: AsyncSession + ): + channel_id = channel.channel_id + assert channel_id is not None + + if channel_id not in self.channels: + self.channels[channel_id] = [] + for user_id in [user.id for user in users]: + assert user_id is not None + if user_id not in self.channels[channel_id]: + self.channels[channel_id].append(user_id) + + for user in users: + assert user.id is not None + channel_resp = await ChatChannelResp.from_db( + channel, + session, + user, + self.redis, + self.channels[channel_id] + if channel.type != ChannelType.PUBLIC + else None, + ) + client = self.connect_client.get(user.id) + if client: + await self.send_event( + client, + ChatEvent( + event="chat.channel.join", + data=channel_resp.model_dump(), + ), + ) + async def join_channel( self, user: User, channel: ChatChannel, session: AsyncSession ) -> ChatChannelResp: user_id = user.id channel_id = channel.channel_id assert channel_id is not None + assert user_id is not None if channel_id not in self.channels: self.channels[channel_id] = [] @@ -88,7 +124,11 @@ class ChatServer: self.channels[channel_id].append(user_id) channel_resp = await ChatChannelResp.from_db( - channel, session, self.channels[channel_id], user, self.redis + channel, + session, + user, + self.redis, + self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, ) client = self.connect_client.get(user_id) @@ -109,15 +149,16 @@ class ChatServer: user_id = user.id channel_id = channel.channel_id assert channel_id is not None + assert user_id is not None if channel_id in self.channels and user_id in self.channels[channel_id]: self.channels[channel_id].remove(user_id) - if not self.channels.get(channel_id): + if (c := self.channels.get(channel_id)) is not None and not c: del self.channels[channel_id] channel_resp = await ChatChannelResp.from_db( - channel, session, self.channels.get(channel_id, []), user, self.redis + channel, session, user, self.redis, self.channels[channel_id] ) client = self.connect_client.get(user_id) if client: