feat(chat): support pm

This commit is contained in:
MingxuanGame
2025-08-16 07:48:19 +00:00
parent f992e4cc71
commit 368bdfe588
5 changed files with 316 additions and 16 deletions

View File

@@ -16,6 +16,7 @@ from sqlmodel import (
ForeignKey, ForeignKey,
Relationship, Relationship,
SQLModel, SQLModel,
col,
select, select,
) )
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -65,6 +66,15 @@ class ChatChannel(ChatChannelBase, table=True):
await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first() ).first()
@classmethod
async def get_pm_channel(
cls, user1: int, user2: int, session: AsyncSession
) -> "ChatChannel | None":
channel = await cls.get(f"pm_{user1}_{user2}", session)
if channel is None:
channel = await cls.get(f"pm_{user2}_{user1}", session)
return channel
class ChatChannelResp(ChatChannelBase): class ChatChannelResp(ChatChannelBase):
channel_id: int channel_id: int
@@ -73,8 +83,8 @@ class ChatChannelResp(ChatChannelBase):
current_user_attributes: ChatUserAttributes | None = None current_user_attributes: ChatUserAttributes | None = None
last_read_id: int | None = None last_read_id: int | None = None
last_message_id: int | None = None last_message_id: int | None = None
recent_messages: list[str] | None = None recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
users: list[int] | None = None users: list[int] = Field(default_factory=list)
message_length_limit: int = 1000 message_length_limit: int = 1000
@classmethod @classmethod
@@ -82,9 +92,10 @@ class ChatChannelResp(ChatChannelBase):
cls, cls,
channel: ChatChannel, channel: ChatChannel,
session: AsyncSession, session: AsyncSession,
users: list[int],
user: User, user: User,
redis: Redis, redis: Redis,
users: list[int] | None = None,
include_recent_messages: bool = False,
) -> Self: ) -> Self:
c = cls.model_validate(channel) c = cls.model_validate(channel)
silence = ( silence = (
@@ -123,9 +134,33 @@ class ChatChannelResp(ChatChannelBase):
c.moderated = False c.moderated = False
c.current_user_attributes = attribute c.current_user_attributes = attribute
c.users = users if c.type != ChannelType.PUBLIC and users is not None:
c.users = users
c.last_message_id = last_msg c.last_message_id = last_msg
c.last_read_id = last_read_id c.last_read_id = last_read_id
if include_recent_messages:
messages = (
await session.exec(
select(ChatMessage)
.where(ChatMessage.channel_id == channel.channel_id)
.order_by(col(ChatMessage.timestamp).desc())
.limit(10)
)
).all()
c.recent_messages = [
await ChatMessageResp.from_db(msg, session, user) for msg in messages
]
c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(
select(User.username).where(User.id == target_user_id)
)
c.name = target_name.one()
assert user.id
c.users = [target_user_id, user.id]
return c return c

View File

@@ -168,6 +168,46 @@ class User(AsyncAttrs, UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
) )
async def is_user_can_pm(
self, from_user: "User", session: AsyncSession
) -> tuple[bool, str]:
from .relationship import Relationship, RelationshipType
from_relationship = (
await session.exec(
select(Relationship).where(
Relationship.user_id == from_user.id,
Relationship.target_id == self.id,
)
)
).first()
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
return False, "You have blocked the target user."
if from_user.pm_friends_only and (
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
):
return (
False,
"You have disabled non-friend communications "
"and target user is not your friend.",
)
relationship = (
await session.exec(
select(Relationship).where(
Relationship.user_id == self.id,
Relationship.target_id == from_user.id,
)
)
).first()
if relationship and relationship.type == RelationshipType.BLOCK:
return False, "Target user has blocked you."
if self.pm_friends_only and (
not relationship or relationship.type != RelationshipType.FOLLOW
):
return False, "Target user has disabled non-friend communications"
return True, ""
class UserResp(UserBase): class UserResp(UserBase):
id: int | None = None id: int | None = None

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any, Literal, Self
from app.database.chat import ( from app.database.chat import (
ChannelType, ChannelType,
@@ -9,15 +9,16 @@ from app.database.chat import (
) )
from app.database.lazer_user import User, UserResp from app.database.lazer_user import User, UserResp
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_db, get_redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router from app.router.v2 import api_v2_router as router
from .server import server from .server import server
from fastapi import Depends, HTTPException, Query, Security from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -37,6 +38,7 @@ async def get_update(
): ):
resp = UpdateResponse() resp = UpdateResponse()
if "presence" in includes: if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id) channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids: for channel_id in channel_ids:
channel = await ChatChannel.get(channel_id, session) channel = await ChatChannel.get(channel_id, session)
@@ -45,9 +47,11 @@ async def get_update(
await ChatChannelResp.from_db( await ChatChannelResp.from_db(
channel, channel,
session, session,
server.channels.get(channel_id, []),
current_user, current_user,
redis, redis,
server.channels.get(channel_id, [])
if channel.type != ChannelType.PUBLIC
else None,
) )
) )
return resp return resp
@@ -103,9 +107,11 @@ async def get_channel_list(
await ChatChannelResp.from_db( await ChatChannelResp.from_db(
channel, channel,
session, session,
server.channels.get(channel.channel_id, []),
current_user, current_user,
redis, redis,
server.channels.get(channel.channel_id, [])
if channel.type != ChannelType.PUBLIC
else None,
) )
) )
return results return results
@@ -127,12 +133,111 @@ async def get_channel(
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id is not None assert db_channel.channel_id is not None
users = []
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
if len(user_ids) != 2:
raise HTTPException(status_code=404, detail="Target user not found")
for id_ in user_ids:
if int(id_) == current_user.id:
continue
target_user = await session.get(User, int(id_))
if target_user is None:
raise HTTPException(status_code=404, detail="Target user not found")
users.extend([target_user, current_user])
break
return GetChannelResp( return GetChannelResp(
channel=await ChatChannelResp.from_db( channel=await ChatChannelResp.from_db(
db_channel, db_channel,
session, session,
server.channels.get(db_channel.channel_id, []),
current_user, current_user,
redis, redis,
server.channels.get(db_channel.channel_id, [])
if db_channel.type != ChannelType.PUBLIC
else None,
) )
) )
class CreateChannelReq(BaseModel):
class AnnounceChannel(BaseModel):
name: str
description: str
message: str | None = None
type: Literal["ANNOUNCE", "PM"] = "PM"
target_id: int | None = None
target_ids: list[int] | None = None
channel: AnnounceChannel | None = None
@model_validator(mode="after")
def check(self) -> Self:
if self.type == "PM":
if self.target_id is None:
raise ValueError("target_id must be set for PM channels")
else:
if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError(
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
return self
@router.post("/chat/channels")
async def create_channel(
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
if req.type == "PM":
target = await session.get(User, req.target_id)
if not target:
raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType]
req.target_id, # pyright: ignore[reportArgumentType]
session,
)
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
channel = await ChatChannel.get(channel_name, session)
if channel is None:
channel = ChatChannel(
name=channel_name,
description=req.channel.description
if req.channel
else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(
select(User).where(col(User.id).in_(req.target_ids or []))
)
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
assert channel.channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, []),
include_recent_messages=True,
)

View File

@@ -1,9 +1,15 @@
from __future__ import annotations from __future__ import annotations
from app.database import ChatMessageResp from app.database import ChatMessageResp
from app.database.chat import ChatChannel, ChatMessage, MessageType from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
MessageType,
)
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db from app.dependencies.database import get_db, get_redis
from app.dependencies.param import BodyOrForm from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router from app.router.v2 import api_v2_router as router
@@ -12,6 +18,7 @@ from .server import server
from fastapi import Depends, HTTPException, Query, Security from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -42,6 +49,9 @@ async def send_message(
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
assert current_user.id
msg = ChatMessage( msg = ChatMessage(
channel_id=db_channel.channel_id, channel_id=db_channel.channel_id,
content=req.message, content=req.message,
@@ -95,4 +105,73 @@ async def mark_as_read(
db_channel = await ChatChannel.get(channel, session) db_channel = await ChatChannel.get(channel, session)
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
await server.mark_as_read(db_channel.channel_id, message) await server.mark_as_read(db_channel.channel_id, message)
class PMReq(BaseModel):
target_id: int
message: str
is_action: bool = False
uuid: str | None = None
class NewPMResp(BaseModel):
channel: ChatChannelResp
message: ChatMessageResp
new_channel_id: int
@router.post("/chat/new")
async def create_new_pm(
req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
user_id = current_user.id
target = await session.get(User, req.target_id)
if target is None:
raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
assert user_id
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
if channel is None:
channel = ChatChannel(
name=f"pm_{user_id}_{req.target_id}",
description="Private message channel",
type=ChannelType.PM,
)
session.add(channel)
await session.commit()
await session.refresh(channel)
await session.refresh(target)
await session.refresh(current_user)
assert channel.channel_id
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]
)
msg = ChatMessage(
channel_id=channel.channel_id,
content=req.message,
sender_id=user_id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(channel)
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
await server.send_message_to_channel(message_resp)
return NewPMResp(
channel=channel_resp,
message=message_resp,
new_channel_id=channel_resp.channel_id,
)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
from app.database.chat import ChatChannel, ChatChannelResp, ChatMessageResp from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import DBFactory, get_db_factory, get_redis from app.dependencies.database import DBFactory, get_db_factory, get_redis
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
@@ -73,14 +73,50 @@ class ChatServer:
), ),
) )
) )
assert message.message_id
await self.mark_as_read(message.channel_id, message.message_id) await self.mark_as_read(message.channel_id, message.message_id)
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
channel_id = channel.channel_id
assert channel_id is not None
if channel_id not in self.channels:
self.channels[channel_id] = []
for user_id in [user.id for user in users]:
assert user_id is not None
if user_id not in self.channels[channel_id]:
self.channels[channel_id].append(user_id)
for user in users:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id]
if channel.type != ChannelType.PUBLIC
else None,
)
client = self.connect_client.get(user.id)
if client:
await self.send_event(
client,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
),
)
async def join_channel( async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp: ) -> ChatChannelResp:
user_id = user.id user_id = user.id
channel_id = channel.channel_id channel_id = channel.channel_id
assert channel_id is not None assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels: if channel_id not in self.channels:
self.channels[channel_id] = [] self.channels[channel_id] = []
@@ -88,7 +124,11 @@ class ChatServer:
self.channels[channel_id].append(user_id) self.channels[channel_id].append(user_id)
channel_resp = await ChatChannelResp.from_db( channel_resp = await ChatChannelResp.from_db(
channel, session, self.channels[channel_id], user, self.redis channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
) )
client = self.connect_client.get(user_id) client = self.connect_client.get(user_id)
@@ -109,15 +149,16 @@ class ChatServer:
user_id = user.id user_id = user.id
channel_id = channel.channel_id channel_id = channel.channel_id
assert channel_id is not None assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]: if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id) self.channels[channel_id].remove(user_id)
if not self.channels.get(channel_id): if (c := self.channels.get(channel_id)) is not None and not c:
del self.channels[channel_id] del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db( channel_resp = await ChatChannelResp.from_db(
channel, session, self.channels.get(channel_id, []), user, self.redis channel, session, user, self.redis, self.channels[channel_id]
) )
client = self.connect_client.get(user_id) client = self.connect_client.get(user_id)
if client: if client: