refactor(service): remove unused services
This commit is contained in:
@@ -1,214 +0,0 @@
|
|||||||
"""
|
|
||||||
Redis 消息队列服务
|
|
||||||
用于实现实时消息推送和异步数据库持久化
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import concurrent.futures
|
|
||||||
from datetime import datetime
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from app.database.chat import ChatMessage, MessageType
|
|
||||||
from app.dependencies.database import get_redis, with_db
|
|
||||||
from app.log import logger
|
|
||||||
from app.utils import bg_tasks
|
|
||||||
|
|
||||||
|
|
||||||
class MessageQueue:
|
|
||||||
"""Redis 消息队列服务"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.redis = get_redis()
|
|
||||||
self._processing = False
|
|
||||||
self._batch_size = 50 # 批量处理大小
|
|
||||||
self._batch_timeout = 1.0 # 批量处理超时时间(秒)
|
|
||||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
|
||||||
|
|
||||||
async def _run_in_executor(self, func, *args):
|
|
||||||
"""在线程池中运行同步 Redis 操作"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(self._executor, func, *args)
|
|
||||||
|
|
||||||
async def start_processing(self):
|
|
||||||
"""启动消息处理任务"""
|
|
||||||
if not self._processing:
|
|
||||||
self._processing = True
|
|
||||||
bg_tasks.add_task(self._process_message_queue)
|
|
||||||
logger.info("Message queue processing started")
|
|
||||||
|
|
||||||
async def stop_processing(self):
|
|
||||||
"""停止消息处理"""
|
|
||||||
self._processing = False
|
|
||||||
logger.info("Message queue processing stopped")
|
|
||||||
|
|
||||||
async def enqueue_message(self, message_data: dict) -> str:
|
|
||||||
"""
|
|
||||||
将消息加入 Redis 队列(实时响应)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_data: 消息数据字典,包含所有必要的字段
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
消息的临时 UUID
|
|
||||||
"""
|
|
||||||
# 生成临时 UUID
|
|
||||||
temp_uuid = str(uuid.uuid4())
|
|
||||||
message_data["temp_uuid"] = temp_uuid
|
|
||||||
message_data["timestamp"] = datetime.now().isoformat()
|
|
||||||
message_data["status"] = "pending" # pending, processing, completed, failed
|
|
||||||
|
|
||||||
# 将消息存储到 Redis
|
|
||||||
await self._run_in_executor(lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data))
|
|
||||||
await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
|
|
||||||
|
|
||||||
# 加入处理队列
|
|
||||||
await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
|
|
||||||
|
|
||||||
logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
|
|
||||||
return temp_uuid
|
|
||||||
|
|
||||||
async def get_message_status(self, temp_uuid: str) -> dict | None:
|
|
||||||
"""获取消息状态"""
|
|
||||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
|
||||||
if not message_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return message_data
|
|
||||||
|
|
||||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
|
||||||
"""
|
|
||||||
从 Redis 获取缓存的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: 频道 ID
|
|
||||||
limit: 限制数量
|
|
||||||
since: 获取自此消息 ID 之后的消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
消息列表
|
|
||||||
"""
|
|
||||||
# 从 Redis 获取频道最近的消息 UUID 列表
|
|
||||||
message_uuids = await self._run_in_executor(self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1)
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
for uuid_str in message_uuids:
|
|
||||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
|
|
||||||
if message_data:
|
|
||||||
# 检查是否满足 since 条件
|
|
||||||
if since > 0 and "message_id" in message_data:
|
|
||||||
if int(message_data["message_id"]) <= since:
|
|
||||||
continue
|
|
||||||
|
|
||||||
messages.append(message_data)
|
|
||||||
|
|
||||||
return messages[::-1] # 返回时间顺序
|
|
||||||
|
|
||||||
async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
|
|
||||||
"""将消息 UUID 缓存到频道消息列表"""
|
|
||||||
# 添加到频道消息列表开头
|
|
||||||
await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
|
|
||||||
# 限制缓存大小
|
|
||||||
await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
|
|
||||||
# 设置过期时间(24小时)
|
|
||||||
await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
|
|
||||||
|
|
||||||
async def _process_message_queue(self):
|
|
||||||
"""异步处理消息队列,批量写入数据库"""
|
|
||||||
while self._processing:
|
|
||||||
try:
|
|
||||||
# 批量获取消息
|
|
||||||
message_uuids = []
|
|
||||||
for _ in range(self._batch_size):
|
|
||||||
result = await self._run_in_executor(lambda: self.redis.brpop(["message_queue"], timeout=1))
|
|
||||||
if result:
|
|
||||||
message_uuids.append(result[1])
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
if message_uuids:
|
|
||||||
await self._process_message_batch(message_uuids)
|
|
||||||
else:
|
|
||||||
# 没有消息时短暂等待
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing message queue: {e}")
|
|
||||||
await asyncio.sleep(1) # 错误时等待1秒再重试
|
|
||||||
|
|
||||||
async def _process_message_batch(self, message_uuids: list[str]):
|
|
||||||
"""批量处理消息写入数据库"""
|
|
||||||
async with with_db() as session:
|
|
||||||
messages_to_insert = []
|
|
||||||
|
|
||||||
for temp_uuid in message_uuids:
|
|
||||||
try:
|
|
||||||
# 获取消息数据
|
|
||||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
|
||||||
if not message_data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 更新状态为处理中
|
|
||||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
|
|
||||||
|
|
||||||
# 创建数据库消息对象
|
|
||||||
msg = ChatMessage(
|
|
||||||
channel_id=int(message_data["channel_id"]),
|
|
||||||
content=message_data["content"],
|
|
||||||
sender_id=int(message_data["sender_id"]),
|
|
||||||
type=MessageType(message_data["type"]),
|
|
||||||
uuid=message_data.get("user_uuid"), # 用户提供的 UUID(如果有)
|
|
||||||
)
|
|
||||||
|
|
||||||
messages_to_insert.append((msg, temp_uuid))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error preparing message {temp_uuid}: {e}")
|
|
||||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
|
||||||
|
|
||||||
if messages_to_insert:
|
|
||||||
try:
|
|
||||||
# 批量插入数据库
|
|
||||||
for msg, temp_uuid in messages_to_insert:
|
|
||||||
session.add(msg)
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
# 更新所有消息状态和真实 ID
|
|
||||||
for msg, temp_uuid in messages_to_insert:
|
|
||||||
await session.refresh(msg)
|
|
||||||
await self._run_in_executor(
|
|
||||||
lambda: self.redis.hset(
|
|
||||||
f"msg:{temp_uuid}",
|
|
||||||
mapping={
|
|
||||||
"status": "completed",
|
|
||||||
"message_id": str(msg.message_id),
|
|
||||||
"created_at": msg.timestamp.isoformat() if msg.timestamp else "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error inserting messages to database: {e}")
|
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
# 标记所有消息为失败
|
|
||||||
for _, temp_uuid in messages_to_insert:
|
|
||||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
|
||||||
|
|
||||||
|
|
||||||
# 全局消息队列实例
|
|
||||||
message_queue = MessageQueue()
|
|
||||||
|
|
||||||
|
|
||||||
async def start_message_queue():
|
|
||||||
"""启动消息队列处理"""
|
|
||||||
await message_queue.start_processing()
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_message_queue():
|
|
||||||
"""停止消息队列处理"""
|
|
||||||
await message_queue.stop_processing()
|
|
||||||
@@ -1,290 +0,0 @@
|
|||||||
"""
|
|
||||||
消息队列处理服务
|
|
||||||
专门处理 Redis 消息队列的异步写入数据库
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
|
||||||
|
|
||||||
from app.database.chat import ChatMessage, MessageType
|
|
||||||
from app.dependencies.database import get_redis_message, with_db
|
|
||||||
from app.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
class MessageQueueProcessor:
|
|
||||||
"""消息队列处理器"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.redis_message = get_redis_message()
|
|
||||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
|
||||||
self._processing = False
|
|
||||||
self._queue_task = None
|
|
||||||
|
|
||||||
async def _redis_exec(self, func, *args, **kwargs):
|
|
||||||
"""在线程池中执行 Redis 操作"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
|
|
||||||
|
|
||||||
async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str):
|
|
||||||
"""将消息缓存到 Redis"""
|
|
||||||
try:
|
|
||||||
# 存储消息数据
|
|
||||||
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data)
|
|
||||||
await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
|
|
||||||
|
|
||||||
# 加入频道消息列表
|
|
||||||
await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid)
|
|
||||||
await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条
|
|
||||||
await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期
|
|
||||||
|
|
||||||
# 加入异步处理队列
|
|
||||||
await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid)
|
|
||||||
|
|
||||||
logger.info(f"Message cached to Redis: {temp_uuid}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to cache message to Redis: {e}")
|
|
||||||
|
|
||||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
|
||||||
"""从 Redis 获取缓存的消息"""
|
|
||||||
try:
|
|
||||||
message_uuids = await self._redis_exec(
|
|
||||||
self.redis_message.lrange,
|
|
||||||
f"channel:{channel_id}:messages",
|
|
||||||
0,
|
|
||||||
limit - 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
for temp_uuid in message_uuids:
|
|
||||||
# 解码 UUID 如果它是字节类型
|
|
||||||
if isinstance(temp_uuid, bytes):
|
|
||||||
temp_uuid = temp_uuid.decode("utf-8")
|
|
||||||
|
|
||||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
|
||||||
if raw_data:
|
|
||||||
# 解码 Redis 返回的字节数据
|
|
||||||
message_data = {
|
|
||||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
|
|
||||||
if isinstance(v, bytes)
|
|
||||||
else v
|
|
||||||
for k, v in raw_data.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# 检查 since 条件
|
|
||||||
if since > 0 and message_data.get("message_id"):
|
|
||||||
if int(message_data["message_id"]) <= since:
|
|
||||||
continue
|
|
||||||
messages.append(message_data)
|
|
||||||
|
|
||||||
return messages[::-1] # 按时间顺序返回
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get cached messages: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def update_message_status(self, temp_uuid: str, status: str, message_id: int | None = None):
|
|
||||||
"""更新消息状态"""
|
|
||||||
try:
|
|
||||||
update_data = {"status": status}
|
|
||||||
if message_id:
|
|
||||||
update_data["message_id"] = str(message_id)
|
|
||||||
update_data["db_timestamp"] = datetime.now().isoformat()
|
|
||||||
|
|
||||||
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to update message status: {e}")
|
|
||||||
|
|
||||||
async def get_message_status(self, temp_uuid: str) -> dict | None:
|
|
||||||
"""获取消息状态"""
|
|
||||||
try:
|
|
||||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
|
||||||
if not raw_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 解码 Redis 返回的字节数据
|
|
||||||
return {
|
|
||||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
|
|
||||||
for k, v in raw_data.items()
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get message status: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _process_message_queue(self):
|
|
||||||
"""处理消息队列,异步写入数据库"""
|
|
||||||
logger.info("Message queue processing started")
|
|
||||||
|
|
||||||
while self._processing:
|
|
||||||
try:
|
|
||||||
# 批量获取消息
|
|
||||||
message_uuids = []
|
|
||||||
for _ in range(20): # 批量处理20条消息
|
|
||||||
result = await self._redis_exec(self.redis_message.brpop, ["message_write_queue"], timeout=1)
|
|
||||||
if result:
|
|
||||||
# result是 (queue_name, value) 的元组,需要解码
|
|
||||||
uuid_value = result[1]
|
|
||||||
if isinstance(uuid_value, bytes):
|
|
||||||
uuid_value = uuid_value.decode("utf-8")
|
|
||||||
message_uuids.append(uuid_value)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not message_uuids:
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 批量写入数据库
|
|
||||||
await self._process_message_batch(message_uuids)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in message queue processing: {e}")
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
logger.info("Message queue processing stopped")
|
|
||||||
|
|
||||||
async def _process_message_batch(self, message_uuids: list[str]):
|
|
||||||
"""批量处理消息写入数据库"""
|
|
||||||
async with with_db() as session:
|
|
||||||
for temp_uuid in message_uuids:
|
|
||||||
try:
|
|
||||||
# 获取消息数据并解码
|
|
||||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
|
||||||
if not raw_data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 解码 Redis 返回的字节数据
|
|
||||||
message_data = {
|
|
||||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
|
|
||||||
if isinstance(v, bytes)
|
|
||||||
else v
|
|
||||||
for k, v in raw_data.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_data.get("status") != "pending":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 更新状态为处理中
|
|
||||||
await self.update_message_status(temp_uuid, "processing")
|
|
||||||
|
|
||||||
# 创建数据库消息
|
|
||||||
msg = ChatMessage(
|
|
||||||
channel_id=int(message_data["channel_id"]),
|
|
||||||
content=message_data["content"],
|
|
||||||
sender_id=int(message_data["sender_id"]),
|
|
||||||
type=MessageType(message_data["type"]),
|
|
||||||
uuid=message_data.get("user_uuid") or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(msg)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(msg)
|
|
||||||
|
|
||||||
# 更新成功状态,包含临时消息ID映射
|
|
||||||
await self.update_message_status(temp_uuid, "completed", msg.message_id)
|
|
||||||
|
|
||||||
# 如果有临时消息ID,存储映射关系并通知客户端更新
|
|
||||||
if message_data.get("temp_message_id"):
|
|
||||||
temp_msg_id = int(message_data["temp_message_id"])
|
|
||||||
await self._redis_exec(
|
|
||||||
self.redis_message.set,
|
|
||||||
f"temp_to_real:{temp_msg_id}",
|
|
||||||
str(msg.message_id),
|
|
||||||
ex=3600, # 1小时过期
|
|
||||||
)
|
|
||||||
|
|
||||||
# 发送消息ID更新通知到频道
|
|
||||||
channel_id = int(message_data["channel_id"])
|
|
||||||
await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, "
|
|
||||||
f"temp_id: {message_data.get('temp_message_id')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to process message {temp_uuid}: {e}")
|
|
||||||
await self.update_message_status(temp_uuid, "failed")
|
|
||||||
|
|
||||||
async def _notify_message_update(
|
|
||||||
self,
|
|
||||||
channel_id: int,
|
|
||||||
temp_message_id: int,
|
|
||||||
real_message_id: int,
|
|
||||||
message_data: dict,
|
|
||||||
):
|
|
||||||
"""通知客户端消息ID已更新"""
|
|
||||||
try:
|
|
||||||
# 通过 Redis 发布消息更新事件,由聊天通知服务分发到客户端
|
|
||||||
update_event = {
|
|
||||||
"event": "chat.message.update",
|
|
||||||
"data": {
|
|
||||||
"channel_id": channel_id,
|
|
||||||
"temp_message_id": temp_message_id,
|
|
||||||
"real_message_id": real_message_id,
|
|
||||||
"timestamp": message_data.get("timestamp"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
await self._redis_exec(
|
|
||||||
self.redis_message.publish,
|
|
||||||
f"chat_updates:{channel_id}",
|
|
||||||
json.dumps(update_event),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to notify message update: {e}")
|
|
||||||
|
|
||||||
def start_processing(self):
|
|
||||||
"""启动消息队列处理"""
|
|
||||||
if not self._processing:
|
|
||||||
self._processing = True
|
|
||||||
self._queue_task = asyncio.create_task(self._process_message_queue())
|
|
||||||
logger.info("Message queue processor started")
|
|
||||||
|
|
||||||
def stop_processing(self):
|
|
||||||
"""停止消息队列处理"""
|
|
||||||
if self._processing:
|
|
||||||
self._processing = False
|
|
||||||
if self._queue_task:
|
|
||||||
self._queue_task.cancel()
|
|
||||||
self._queue_task = None
|
|
||||||
logger.info("Message queue processor stopped")
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""清理资源"""
|
|
||||||
if hasattr(self, "executor"):
|
|
||||||
self.executor.shutdown(wait=False)
|
|
||||||
|
|
||||||
|
|
||||||
# 全局消息队列处理器实例
|
|
||||||
message_queue_processor = MessageQueueProcessor()
|
|
||||||
|
|
||||||
|
|
||||||
def start_message_processing():
|
|
||||||
"""启动消息队列处理"""
|
|
||||||
message_queue_processor.start_processing()
|
|
||||||
|
|
||||||
|
|
||||||
def stop_message_processing():
|
|
||||||
"""停止消息队列处理"""
|
|
||||||
message_queue_processor.stop_processing()
|
|
||||||
|
|
||||||
|
|
||||||
async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid: str):
|
|
||||||
"""将消息缓存到 Redis - 便捷接口"""
|
|
||||||
await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
|
||||||
"""从 Redis 获取缓存的消息 - 便捷接口"""
|
|
||||||
return await message_queue_processor.get_cached_messages(channel_id, limit, since)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_message_status(temp_uuid: str) -> dict | None:
|
|
||||||
"""获取消息状态 - 便捷接口"""
|
|
||||||
return await message_queue_processor.get_message_status(temp_uuid)
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
"""
|
|
||||||
优化的消息服务
|
|
||||||
结合 Redis 缓存和异步数据库写入实现实时消息传送
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from app.database.chat import (
|
|
||||||
ChannelType,
|
|
||||||
ChatMessageResp,
|
|
||||||
MessageType,
|
|
||||||
)
|
|
||||||
from app.database.user import User
|
|
||||||
from app.log import logger
|
|
||||||
from app.service.message_queue import message_queue
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizedMessageService:
|
|
||||||
"""优化的消息服务"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.message_queue = message_queue
|
|
||||||
|
|
||||||
async def send_message_fast(
|
|
||||||
self,
|
|
||||||
channel_id: int,
|
|
||||||
channel_type: ChannelType,
|
|
||||||
channel_name: str,
|
|
||||||
content: str,
|
|
||||||
sender: User,
|
|
||||||
is_action: bool = False,
|
|
||||||
user_uuid: str | None = None,
|
|
||||||
session: AsyncSession | None = None,
|
|
||||||
) -> ChatMessageResp:
|
|
||||||
"""
|
|
||||||
快速发送消息(先缓存到 Redis,异步写入数据库)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: 频道 ID
|
|
||||||
channel_type: 频道类型
|
|
||||||
channel_name: 频道名称
|
|
||||||
content: 消息内容
|
|
||||||
sender: 发送者
|
|
||||||
is_action: 是否为动作消息
|
|
||||||
user_uuid: 用户提供的 UUID
|
|
||||||
session: 数据库会话(可选,用于一些验证)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
消息响应对象
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 准备消息数据
|
|
||||||
message_data = {
|
|
||||||
"channel_id": str(channel_id),
|
|
||||||
"content": content,
|
|
||||||
"sender_id": str(sender.id),
|
|
||||||
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
|
|
||||||
"user_uuid": user_uuid or "",
|
|
||||||
"channel_type": channel_type.value,
|
|
||||||
"channel_name": channel_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 立即将消息加入 Redis 队列(实时响应)
|
|
||||||
temp_uuid = await self.message_queue.enqueue_message(message_data)
|
|
||||||
|
|
||||||
# 缓存到频道消息列表
|
|
||||||
await self.message_queue.cache_channel_message(channel_id, temp_uuid)
|
|
||||||
|
|
||||||
# 创建临时响应对象(简化版本,用于立即响应)
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.database.user import UserResp
|
|
||||||
|
|
||||||
# 创建基本的用户响应对象
|
|
||||||
user_resp = UserResp(
|
|
||||||
id=sender.id,
|
|
||||||
username=sender.username,
|
|
||||||
country_code=getattr(sender, "country_code", "XX"),
|
|
||||||
# 基本字段,其他复杂字段可以后续异步加载
|
|
||||||
)
|
|
||||||
|
|
||||||
temp_response = ChatMessageResp(
|
|
||||||
message_id=0, # 临时 ID,等数据库写入后会更新
|
|
||||||
channel_id=channel_id,
|
|
||||||
content=content,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
sender_id=sender.id,
|
|
||||||
sender=user_resp,
|
|
||||||
is_action=is_action,
|
|
||||||
uuid=user_uuid,
|
|
||||||
)
|
|
||||||
temp_response.temp_uuid = temp_uuid # 添加临时 UUID 用于后续更新
|
|
||||||
|
|
||||||
logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
|
|
||||||
return temp_response
|
|
||||||
|
|
||||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
|
||||||
"""
|
|
||||||
获取缓存的消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel_id: 频道 ID
|
|
||||||
limit: 限制数量
|
|
||||||
since: 获取自此消息 ID 之后的消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
消息列表
|
|
||||||
"""
|
|
||||||
return await self.message_queue.get_cached_messages(channel_id, limit, since)
|
|
||||||
|
|
||||||
async def get_message_status(self, temp_uuid: str) -> dict | None:
|
|
||||||
"""
|
|
||||||
获取消息状态
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temp_uuid: 临时消息 UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
消息状态信息
|
|
||||||
"""
|
|
||||||
return await self.message_queue.get_message_status(temp_uuid)
|
|
||||||
|
|
||||||
async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> dict | None: # noqa: ASYNC109
|
|
||||||
"""
|
|
||||||
等待消息持久化到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temp_uuid: 临时消息 UUID
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
完成后的消息状态
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
for _ in range(timeout * 10): # 每100ms检查一次
|
|
||||||
status = await self.get_message_status(temp_uuid)
|
|
||||||
if status and status.get("status") in ["completed", "failed"]:
|
|
||||||
return status
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# 全局优化消息服务实例
|
|
||||||
optimized_message_service = OptimizedMessageService()
|
|
||||||
Reference in New Issue
Block a user