整理代码
This commit is contained in:
@@ -62,7 +62,7 @@ async def get_update(
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
|
||||
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
db_channel,
|
||||
@@ -122,9 +122,7 @@ async def join_channel(
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
@@ -154,9 +152,7 @@ async def leave_channel(
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
@@ -187,7 +183,7 @@ async def get_channel_list(
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
|
||||
|
||||
assert channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
@@ -230,19 +226,17 @@ async def get_channel(
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
|
||||
assert channel_id is not None
|
||||
|
||||
users = []
|
||||
@@ -325,7 +319,9 @@ async def create_channel(
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel_name)
|
||||
)
|
||||
channel = result.first()
|
||||
|
||||
if channel is None:
|
||||
@@ -350,11 +346,11 @@ async def create_channel(
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
@@ -16,14 +11,13 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import Database, get_redis, get_redis_message
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||
from app.router.v2 import api_v2_router as router
|
||||
from app.service.optimized_message import optimized_message_service
|
||||
from app.service.redis_message_system import redis_message_system
|
||||
from app.log import logger
|
||||
|
||||
from .banchobot import bot
|
||||
from .server import server
|
||||
@@ -106,11 +100,9 @@ async def send_message(
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
@@ -118,29 +110,29 @@ async def send_message(
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
|
||||
assert channel_id is not None
|
||||
assert current_user.id
|
||||
|
||||
|
||||
# 使用 Redis 消息系统发送消息 - 立即返回
|
||||
resp = await redis_message_system.send_message(
|
||||
channel_id=channel_id,
|
||||
user=current_user,
|
||||
content=req.message,
|
||||
is_action=req.is_action,
|
||||
user_uuid=req.uuid
|
||||
user_uuid=req.uuid,
|
||||
)
|
||||
|
||||
|
||||
# 立即广播消息给所有客户端
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and channel_type == ChannelType.PUBLIC
|
||||
)
|
||||
|
||||
|
||||
# 处理机器人命令
|
||||
if is_bot_command:
|
||||
await bot.try_handle(current_user, db_channel, req.message, session)
|
||||
|
||||
|
||||
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
|
||||
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
|
||||
temp_msg = ChatMessage(
|
||||
@@ -151,7 +143,7 @@ async def send_message(
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
|
||||
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
@@ -163,7 +155,7 @@ async def send_message(
|
||||
await server.new_private_notification(
|
||||
ChannelMessageTeam.init(temp_msg, current_user)
|
||||
)
|
||||
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@@ -191,11 +183,9 @@ async def get_message(
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
@@ -218,7 +208,7 @@ async def get_message(
|
||||
query = query.where(col(ChatMessage.message_id) > since)
|
||||
if until is not None:
|
||||
query = query.where(col(ChatMessage.message_id) < until)
|
||||
|
||||
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
messages = (await session.exec(query)).all()
|
||||
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
@@ -247,14 +237,12 @@ async def mark_as_read(
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
@@ -96,8 +96,10 @@ class ChatServer:
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
|
||||
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
@@ -107,24 +109,32 @@ class ChatServer:
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
logger.info(
|
||||
f"Broadcasting message to all users in channel {message.channel_id}"
|
||||
)
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
await self.redis.set(
|
||||
f"chat:{message.channel_id}:last_msg", message.message_id
|
||||
)
|
||||
logger.info(
|
||||
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
logger.debug(
|
||||
f"Skipping last message update for message ID: {message.message_id}"
|
||||
)
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
@@ -340,11 +350,9 @@ async def chat_websocket(
|
||||
server.connect(user_id, websocket)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == 1)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
|
||||
).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
|
||||
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
|
||||
Reference in New Issue
Block a user