add message redis

This commit is contained in:
咕谷酱
2025-08-22 01:49:03 +08:00
parent 36b695b531
commit 1fe603f416
11 changed files with 1461 additions and 86 deletions

View File

@@ -53,16 +53,24 @@ async def get_update(
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
channel = await ChatChannel.get(channel_id, session)
if channel:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
resp.presence.append(
await ChatChannelResp.from_db(
channel,
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel.type != ChannelType.PUBLIC
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -105,7 +113,19 @@ async def join_channel(
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -125,7 +145,19 @@ async def leave_channel(
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -152,15 +184,19 @@ async def get_channel_list(
).all()
results = []
for channel in channels:
assert channel.channel_id is not None
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
assert channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, [])
if channel.type != ChannelType.PUBLIC
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -185,14 +221,33 @@ async def get_channel(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id is not None
# 立即提取需要的属性
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
users = []
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
if len(user_ids) != 2:
raise HTTPException(status_code=404, detail="Target user not found")
for id_ in user_ids:
@@ -210,8 +265,8 @@ async def get_channel(
session,
current_user,
redis,
server.channels.get(db_channel.channel_id, [])
if db_channel.type != ChannelType.PUBLIC
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -270,7 +325,8 @@ async def create_channel(
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
channel = await ChatChannel.get(channel_name, session)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
channel = result.first()
if channel is None:
channel = ChatChannel(
@@ -294,12 +350,16 @@ async def create_channel(
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
assert channel.channel_id
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, []),
server.channels.get(channel_id, []),
include_recent_messages=True,
)

View File

@@ -1,5 +1,10 @@
from __future__ import annotations
import json
import uuid
from datetime import datetime
from typing import Optional
from app.database import ChatMessageResp
from app.database.chat import (
ChannelType,
@@ -11,11 +16,14 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.lazer_user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.database import Database, get_redis, get_redis_message
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from app.service.optimized_message import optimized_message_service
from app.service.redis_message_system import redis_message_system
from app.log import logger
from .banchobot import bot
from .server import server
@@ -89,42 +97,73 @@ async def send_message(
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
# 立即提取所有需要的属性,避免后续延迟加载
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
assert current_user.id
msg = ChatMessage(
channel_id=db_channel.channel_id,
# 使用 Redis 消息系统发送消息 - 立即返回
resp = await redis_message_system.send_message(
channel_id=channel_id,
user=current_user,
content=req.message,
sender_id=current_user.id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
is_action=req.is_action,
user_uuid=req.uuid
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(db_channel)
resp = await ChatMessageResp.from_db(msg, session, current_user)
# 立即广播消息给所有客户端
is_bot_command = req.message.startswith("!")
await server.send_message_to_channel(
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
resp, is_bot_command and channel_type == ChannelType.PUBLIC
)
# 处理机器人命令
if is_bot_command:
await bot.try_handle(current_user, db_channel, req.message, session)
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
await server.new_private_notification(
ChannelMessage.init(
msg, current_user, [int(u) for u in user_ids], db_channel.type
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
temp_msg = ChatMessage(
message_id=resp.message_id, # 使用 Redis 系统生成的ID
channel_id=channel_id,
content=req.message,
sender_id=current_user.id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
await server.new_private_notification(
ChannelMessage.init(
temp_msg, current_user, [int(u) for u in user_ids], channel_type
)
)
)
elif db_channel.type == ChannelType.TEAM:
await server.new_private_notification(
ChannelMessageTeam.init(msg, current_user)
)
elif channel_type == ChannelType.TEAM:
await server.new_private_notification(
ChannelMessageTeam.init(temp_msg, current_user)
)
return resp
@@ -143,21 +182,46 @@ async def get_message(
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
messages = await session.exec(
select(ChatMessage)
.where(
ChatMessage.channel_id == db_channel.channel_id,
col(ChatMessage.message_id) > since,
col(ChatMessage.message_id) < until if until is not None else True,
)
.order_by(col(ChatMessage.timestamp).desc())
.limit(limit)
)
# 提取必要的属性避免惰性加载
channel_id = db_channel.channel_id
assert channel_id is not None
# 使用 Redis 消息系统获取消息
try:
messages = await redis_message_system.get_messages(channel_id, limit, since)
return messages
except Exception as e:
logger.warning(f"Failed to get messages from Redis system: {e}")
# 回退到传统数据库查询
pass
# 回退到数据库查询
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
if since > 0:
query = query.where(col(ChatMessage.message_id) > since)
if until is not None:
query = query.where(col(ChatMessage.message_id) < until)
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
messages = (await session.exec(query)).all()
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
resp.reverse()
return resp
@@ -174,12 +238,28 @@ async def mark_as_read(
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
# 立即提取需要的属性
channel_id = db_channel.channel_id
assert channel_id
assert current_user.id
await server.mark_as_read(db_channel.channel_id, current_user.id, message)
await server.mark_as_read(channel_id, current_user.id, message)
class PMReq(BaseModel):

View File

@@ -59,9 +59,14 @@ class ChatServer:
for channel_id, channel in self.channels.items():
if user_id in channel:
channel.remove(user_id)
channel = await ChatChannel.get(channel_id, session)
if channel:
await self.leave_channel(user, channel, session)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@@ -79,8 +84,11 @@ class ChatServer:
await client.send_text(event.model_dump_json())
async def broadcast(self, channel_id: int, event: ChatEvent):
for user_id in self.channels.get(channel_id, []):
users_in_channel = self.channels.get(channel_id, [])
logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")
for user_id in users_in_channel:
await self.send_event(user_id, event)
logger.debug(f"Sent event to user {user_id} in channel {channel_id}")
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
@@ -88,24 +96,35 @@ class ChatServer:
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
event = ChatEvent(
event="chat.message.new",
data={"messages": [message], "users": [message.sender]},
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
self._add_task(self.send_event(message.sender_id, event))
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
self._add_task(
self.broadcast(
message.channel_id,
event,
)
)
assert message.message_id
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
else:
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
@@ -206,27 +225,37 @@ class ChatServer:
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.join_channel(user, channel, session)
await self.join_channel(user, db_channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.leave_channel(user, channel, session)
await self.leave_channel(user, db_channel, session)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
@@ -309,7 +338,13 @@ async def chat_websocket(
user_id = user.id
assert user_id
server.connect(user_id, websocket)
channel = await ChatChannel.get(1, session)
if channel is not None:
await server.join_channel(user, channel, session)
await _listen_stop(websocket, user_id, factory)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == 1)
)
).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)
await _listen_stop(websocket, user_id, factory)