整理代码

This commit is contained in:
咕谷酱
2025-08-22 05:57:28 +08:00
parent ad131c0158
commit ce465aa049
20 changed files with 1078 additions and 799 deletions

View File

@@ -3,53 +3,52 @@ Redis 消息队列服务
用于实现实时消息推送和异步数据库持久化
"""
import asyncio
import json
import uuid
from datetime import datetime
from functools import partial
from typing import Optional, Union
import concurrent.futures
from __future__ import annotations
from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType
import asyncio
import concurrent.futures
from datetime import datetime
import uuid
from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis, with_db
from app.log import logger
class MessageQueue:
"""Redis 消息队列服务"""
def __init__(self):
self.redis = get_redis()
self._processing = False
self._batch_size = 50 # 批量处理大小
self._batch_timeout = 1.0 # 批量处理超时时间(秒)
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
async def _run_in_executor(self, func, *args):
"""在线程池中运行同步 Redis 操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self._executor, func, *args)
async def start_processing(self):
"""启动消息处理任务"""
if not self._processing:
self._processing = True
asyncio.create_task(self._process_message_queue())
logger.info("Message queue processing started")
async def stop_processing(self):
"""停止消息处理"""
self._processing = False
logger.info("Message queue processing stopped")
async def enqueue_message(self, message_data: dict) -> str:
"""
将消息加入 Redis 队列(实时响应)
Args:
message_data: 消息数据字典,包含所有必要的字段
Returns:
消息的临时 UUID
"""
@@ -58,36 +57,42 @@ class MessageQueue:
message_data["temp_uuid"] = temp_uuid
message_data["timestamp"] = datetime.now().isoformat()
message_data["status"] = "pending" # pending, processing, completed, failed
# 将消息存储到 Redis
await self._run_in_executor(
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
)
await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
await self._run_in_executor(
self.redis.expire, f"msg:{temp_uuid}", 3600
) # 1小时过期
# 加入处理队列
await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
return temp_uuid
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
async def get_message_status(self, temp_uuid: str) -> dict | None:
"""获取消息状态"""
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
message_data = await self._run_in_executor(
self.redis.hgetall, f"msg:{temp_uuid}"
)
if not message_data:
return None
return message_data
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
async def get_cached_messages(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
"""
从 Redis 获取缓存的消息
Args:
channel_id: 频道 ID
limit: 限制数量
since: 获取自此消息 ID 之后的消息
Returns:
消息列表
"""
@@ -95,29 +100,39 @@ class MessageQueue:
message_uuids = await self._run_in_executor(
self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
)
messages = []
for uuid_str in message_uuids:
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
message_data = await self._run_in_executor(
self.redis.hgetall, f"msg:{uuid_str}"
)
if message_data:
# 检查是否满足 since 条件
if since > 0 and "message_id" in message_data:
if int(message_data["message_id"]) <= since:
continue
messages.append(message_data)
return messages[::-1] # 返回时间顺序
async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
async def cache_channel_message(
self, channel_id: int, temp_uuid: str, max_cache: int = 100
):
"""将消息 UUID 缓存到频道消息列表"""
# 添加到频道消息列表开头
await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
await self._run_in_executor(
self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid
)
# 限制缓存大小
await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
await self._run_in_executor(
self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1
)
# 设置过期时间24小时
await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
await self._run_in_executor(
self.redis.expire, f"channel:{channel_id}:messages", 86400
)
async def _process_message_queue(self):
"""异步处理消息队列,批量写入数据库"""
while self._processing:
@@ -132,75 +147,90 @@ class MessageQueue:
message_uuids.append(result[1])
else:
break
if message_uuids:
await self._process_message_batch(message_uuids)
else:
# 没有消息时短暂等待
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error processing message queue: {e}")
await asyncio.sleep(1) # 错误时等待1秒再重试
async def _process_message_batch(self, message_uuids: list[str]):
"""批量处理消息写入数据库"""
async with with_db() as session:
messages_to_insert = []
for temp_uuid in message_uuids:
try:
# 获取消息数据
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
message_data = await self._run_in_executor(
self.redis.hgetall, f"msg:{temp_uuid}"
)
if not message_data:
continue
# 更新状态为处理中
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
await self._run_in_executor(
self.redis.hset, f"msg:{temp_uuid}", "status", "processing"
)
# 创建数据库消息对象
msg = ChatMessage(
channel_id=int(message_data["channel_id"]),
content=message_data["content"],
sender_id=int(message_data["sender_id"]),
type=MessageType(message_data["type"]),
uuid=message_data.get("user_uuid") # 用户提供的 UUID如果有
uuid=message_data.get("user_uuid"), # 用户提供的 UUID如果有
)
messages_to_insert.append((msg, temp_uuid))
except Exception as e:
logger.error(f"Error preparing message {temp_uuid}: {e}")
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
await self._run_in_executor(
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
)
if messages_to_insert:
try:
# 批量插入数据库
for msg, temp_uuid in messages_to_insert:
session.add(msg)
await session.commit()
# 更新所有消息状态和真实 ID
for msg, temp_uuid in messages_to_insert:
await session.refresh(msg)
await self._run_in_executor(
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping={
"status": "completed",
"message_id": str(msg.message_id),
"created_at": msg.timestamp.isoformat() if msg.timestamp else ""
})
lambda: self.redis.hset(
f"msg:{temp_uuid}",
mapping={
"status": "completed",
"message_id": str(msg.message_id),
"created_at": msg.timestamp.isoformat()
if msg.timestamp
else "",
},
)
)
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
logger.info(
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}"
)
except Exception as e:
logger.error(f"Error inserting messages to database: {e}")
await session.rollback()
# 标记所有消息为失败
for _, temp_uuid in messages_to_insert:
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
await self._run_in_executor(
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
)
# 全局消息队列实例