# Listing metadata: 248 lines, 8.6 KiB, Python.
"""
|
||
Redis 消息队列服务
|
||
用于实现实时消息推送和异步数据库持久化
|
||
"""
|
||
|
||
from __future__ import annotations

import asyncio
import concurrent.futures
from datetime import datetime
import functools
import uuid

from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis, with_db
from app.log import logger


class MessageQueue:
    """Redis-backed message queue service.

    Messages are written to Redis immediately (hash ``msg:<temp_uuid>`` plus
    an entry in the ``message_queue`` list) so callers get a fast response,
    and a background task later drains the list and persists the messages to
    the database in batches.

    The Redis client is used through its blocking API; every call is pushed
    to a private thread pool so the event loop is never blocked.
    """

    # Upper bound, in seconds, that ``stop_processing`` waits for the worker
    # to finish its in-flight batch before cancelling it.
    _stop_timeout = 5.0

    def __init__(self):
        self.redis = get_redis()
        self._processing = False
        self._batch_size = 50  # maximum number of messages per batch
        # Batch timeout in seconds. NOTE(review): not consulted by this class;
        # the queue poll uses a fixed 1 second BRPOP timeout.
        self._batch_timeout = 1.0
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
        # Strong reference to the worker task: the event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected.
        self._task: asyncio.Task | None = None

    async def _run_in_executor(self, func, *args, **kwargs):
        """Run a blocking Redis call in the thread pool and return its result.

        Args:
            func: Blocking callable to execute.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.
        """
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, hence partial.
        call = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(self._executor, call)

    async def start_processing(self):
        """Start the background worker that drains the queue (idempotent)."""
        if self._processing:
            return
        self._processing = True
        # A worker from a previous start may still be winding down; flipping
        # the flag back on revives it, so only spawn a task when none is live.
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._process_message_queue())
        logger.info("Message queue processing started")

    async def stop_processing(self):
        """Stop the background worker, letting the in-flight batch finish.

        Waits up to ``_stop_timeout`` seconds for the worker to exit on its
        own and cancels it if it has not finished by then.
        """
        self._processing = False
        task = self._task
        if task is not None and not task.done():
            await asyncio.wait({task}, timeout=self._stop_timeout)
            if self._processing:
                # Restarted while we were waiting: the worker is live again.
                return
            if not task.done():
                task.cancel()
        self._task = None
        logger.info("Message queue processing stopped")

    async def enqueue_message(self, message_data: dict) -> str:
        """Store a message in Redis and queue it for persistence.

        ``message_data`` is updated in place with the keys ``temp_uuid``,
        ``timestamp`` and ``status`` before being written to Redis.

        Args:
            message_data: Message fields; the persistence step reads
                ``channel_id``, ``content``, ``sender_id``, ``type`` and
                optionally ``user_uuid`` from it.

        Returns:
            The temporary UUID assigned to the message.
        """
        temp_uuid = str(uuid.uuid4())
        message_data["temp_uuid"] = temp_uuid
        message_data["timestamp"] = datetime.now().isoformat()
        message_data["status"] = "pending"  # pending, processing, completed, failed

        key = f"msg:{temp_uuid}"
        await self._run_in_executor(self.redis.hset, key, mapping=message_data)
        await self._run_in_executor(self.redis.expire, key, 3600)  # expires after 1 hour

        # Hand the message over to the background worker.
        await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)

        logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
        return temp_uuid

    async def get_message_status(self, temp_uuid: str) -> dict | None:
        """Return the Redis hash stored for a message.

        Args:
            temp_uuid: Temporary UUID returned by ``enqueue_message``.

        Returns:
            The stored fields, or ``None`` when nothing is stored under that
            UUID.
        """
        message_data = await self._run_in_executor(
            self.redis.hgetall, f"msg:{temp_uuid}"
        )
        return message_data or None

    async def get_cached_messages(
        self, channel_id: int, limit: int = 50, since: int = 0
    ) -> list[dict]:
        """Fetch a channel's cached messages from Redis.

        Args:
            channel_id: Channel ID.
            limit: Maximum number of cached entries to inspect; values below
                1 yield an empty list.
            since: When positive, messages carrying a ``message_id`` less
                than or equal to this value are skipped. Messages without a
                ``message_id`` are always returned.

        Returns:
            Message hashes in chronological order (oldest first).
        """
        if limit <= 0:
            # LRANGE treats a negative stop index as "count from the tail",
            # so limit=0 would otherwise return the whole list.
            return []

        message_uuids = await self._run_in_executor(
            self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
        )

        messages = []
        for uuid_str in message_uuids:
            message_data = await self._run_in_executor(
                self.redis.hgetall, f"msg:{uuid_str}"
            )
            if not message_data:
                # The message hash is gone; only the list entry remains.
                continue
            if (
                since > 0
                and "message_id" in message_data
                and int(message_data["message_id"]) <= since
            ):
                continue
            messages.append(message_data)

        # The list is newest-first (LPUSH); callers get chronological order.
        return messages[::-1]

    async def cache_channel_message(
        self, channel_id: int, temp_uuid: str, max_cache: int = 100
    ):
        """Record a message UUID in the channel's recent-message list.

        Args:
            channel_id: Channel ID.
            temp_uuid: Temporary UUID of the message.
            max_cache: Number of most recent entries to keep. Expected to be
                positive: with 0 the trim range becomes ``0..-1``, which
                keeps the whole list.
        """
        key = f"channel:{channel_id}:messages"
        await self._run_in_executor(self.redis.lpush, key, temp_uuid)
        # Keep only the newest ``max_cache`` entries.
        await self._run_in_executor(self.redis.ltrim, key, 0, max_cache - 1)
        # Refresh the list's expiry (24 hours).
        await self._run_in_executor(self.redis.expire, key, 86400)

    async def _collect_batch(self) -> list[str]:
        """Pop up to ``_batch_size`` UUIDs from the queue.

        If Redis fails after some UUIDs were already popped, the partial
        batch is returned so those messages are not lost; a failure on an
        empty batch is re-raised.
        """
        batch: list[str] = []
        try:
            for _ in range(self._batch_size):
                if not self._processing:
                    break
                result = await self._run_in_executor(
                    self.redis.brpop, ["message_queue"], timeout=1
                )
                if not result:
                    break
                batch.append(result[1])
        except Exception as e:
            if not batch:
                raise
            logger.error(f"Error collecting message batch: {e}")
        return batch

    async def _process_message_queue(self):
        """Worker loop: drain the queue and persist messages in batches."""
        while self._processing:
            try:
                message_uuids = await self._collect_batch()
                if message_uuids:
                    await self._process_message_batch(message_uuids)
                else:
                    # Nothing queued: pause briefly before polling again.
                    await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Error processing message queue: {e}")
                await asyncio.sleep(1)  # back off for a second before retrying

    async def _mark_failed(self, temp_uuid: str) -> None:
        """Best-effort: flag a message as failed; Redis errors are only logged."""
        try:
            await self._run_in_executor(
                self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
            )
        except Exception as e:
            logger.error(f"Error marking message {temp_uuid} as failed: {e}")

    async def _prepare_message(self, temp_uuid: str) -> ChatMessage | None:
        """Build the database object for a queued message.

        Returns ``None`` when no data is stored for ``temp_uuid``.
        """
        message_data = await self._run_in_executor(
            self.redis.hgetall, f"msg:{temp_uuid}"
        )
        if not message_data:
            return None

        await self._run_in_executor(
            self.redis.hset, f"msg:{temp_uuid}", "status", "processing"
        )

        return ChatMessage(
            channel_id=int(message_data["channel_id"]),
            content=message_data["content"],
            sender_id=int(message_data["sender_id"]),
            type=MessageType(message_data["type"]),
            uuid=message_data.get("user_uuid"),  # user-supplied UUID, if any
        )

    async def _process_message_batch(self, message_uuids: list[str]):
        """Persist a batch of queued messages and record the outcome in Redis.

        Args:
            message_uuids: Temporary UUIDs popped from the queue.
        """
        prepared: list[tuple[ChatMessage, str]] = []
        for temp_uuid in message_uuids:
            try:
                msg = await self._prepare_message(temp_uuid)
            except Exception as e:
                logger.error(f"Error preparing message {temp_uuid}: {e}")
                await self._mark_failed(temp_uuid)
                continue
            if msg is not None:
                prepared.append((msg, temp_uuid))

        if not prepared:
            return

        async with with_db() as session:
            try:
                for msg, _ in prepared:
                    session.add(msg)
                await session.commit()
            except Exception as e:
                logger.error(f"Error inserting messages to database: {e}")
                await session.rollback()
                for _, temp_uuid in prepared:
                    await self._mark_failed(temp_uuid)
                return

            # The rows are committed from here on. A failure while publishing
            # the result must not roll back or flag a persisted message as
            # failed, so each status update is isolated.
            for msg, temp_uuid in prepared:
                try:
                    await session.refresh(msg)
                    await self._run_in_executor(
                        self.redis.hset,
                        f"msg:{temp_uuid}",
                        mapping={
                            "status": "completed",
                            "message_id": str(msg.message_id),
                            "created_at": msg.timestamp.isoformat()
                            if msg.timestamp
                            else "",
                        },
                    )
                except Exception as e:
                    logger.error(
                        f"Error updating status of persisted message {temp_uuid}: {e}"
                    )
                    continue

                logger.info(
                    f"Message {temp_uuid} persisted to DB with ID {msg.message_id}"
                )


# Global message queue instance
message_queue = MessageQueue()


async def start_message_queue():
    """Start background processing on the module-level message queue."""
    await message_queue.start_processing()


async def stop_message_queue():
    """Stop background processing on the module-level message queue."""
    await message_queue.stop_processing()