add message redis
This commit is contained in:
@@ -59,9 +59,14 @@ class ChatServer:
|
||||
for channel_id, channel in self.channels.items():
|
||||
if user_id in channel:
|
||||
channel.remove(user_id)
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
await self.leave_channel(user, channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
@@ -79,8 +84,11 @@ class ChatServer:
|
||||
await client.send_text(event.model_dump_json())
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
for user_id in self.channels.get(channel_id, []):
|
||||
users_in_channel = self.channels.get(channel_id, [])
|
||||
logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")
|
||||
for user_id in users_in_channel:
|
||||
await self.send_event(user_id, event)
|
||||
logger.debug(f"Sent event to user {user_id} in channel {channel_id}")
|
||||
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
@@ -88,24 +96,35 @@ class ChatServer:
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
|
||||
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
)
|
||||
assert message.message_id
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
else:
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
@@ -206,27 +225,37 @@ class ChatServer:
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.join_channel(user, channel, session)
|
||||
await self.join_channel(user, db_channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
@@ -309,7 +338,13 @@ async def chat_websocket(
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
channel = await ChatChannel.get(1, session)
|
||||
if channel is not None:
|
||||
await server.join_channel(user, channel, session)
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == 1)
|
||||
)
|
||||
).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
|
||||
Reference in New Issue
Block a user