refactor(message): replace synchronous Redis client with asynchronous client
This commit is contained in:
@@ -9,7 +9,6 @@ from app.config import settings
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
import redis as sync_redis
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
@@ -41,8 +40,8 @@ redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
# Redis 二进制数据连接 (不自动解码响应,用于存储音频等二进制数据)
|
||||
redis_binary_client = redis.from_url(settings.redis_url, decode_responses=False)
|
||||
|
||||
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
|
||||
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
|
||||
# Redis 消息缓存连接 (db1)
|
||||
redis_message_client: redis.Redis = redis.from_url(settings.redis_url, decode_responses=True, db=1)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
@@ -97,7 +96,7 @@ def get_redis_binary():
|
||||
return redis_binary_client
|
||||
|
||||
|
||||
def get_redis_message():
|
||||
def get_redis_message() -> redis.Redis:
|
||||
"""获取消息专用的 Redis 客户端 (db1)"""
|
||||
return redis_message_client
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import json
|
||||
import time
|
||||
@@ -23,18 +22,12 @@ class RedisMessageSystem:
|
||||
"""Redis 消息系统"""
|
||||
|
||||
def __init__(self):
|
||||
self.redis = get_redis_message()
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
self.redis: Any = get_redis_message()
|
||||
self._batch_timer: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self.batch_interval = 5.0 # 5秒批量存储一次
|
||||
self.max_batch_size = 100 # 每批最多处理100条消息
|
||||
|
||||
async def _redis_exec(self, func, *args, **kwargs):
|
||||
"""在线程池中执行 Redis 操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
channel_id: int,
|
||||
@@ -216,10 +209,10 @@ class RedisMessageSystem:
|
||||
async def _generate_message_id(self, channel_id: int) -> int:
|
||||
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
|
||||
# 使用全局计数器确保所有频道的消息ID都是严格递增的
|
||||
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
|
||||
message_id = await self.redis.incr("global_message_id_counter")
|
||||
|
||||
# 同时更新频道的最后消息ID,用于客户端状态同步
|
||||
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
|
||||
await self.redis.set(f"channel:{channel_id}:last_msg_id", message_id)
|
||||
|
||||
return message_id
|
||||
|
||||
@@ -230,73 +223,70 @@ class RedisMessageSystem:
|
||||
is_multiplayer = message_data.get("is_multiplayer", False)
|
||||
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
await self.redis.hset(
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()},
|
||||
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
|
||||
)
|
||||
|
||||
# 设置消息过期时间(7天)
|
||||
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
|
||||
await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
|
||||
|
||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
||||
channel_messages_key = f"channel:{channel_id}:messages"
|
||||
|
||||
# 更健壮的键类型检查和清理
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
key_type = await self.redis.type(channel_messages_key)
|
||||
if key_type == "none":
|
||||
# 键不存在,这是正常的
|
||||
pass
|
||||
elif key_type != "zset":
|
||||
# 键类型错误,需要清理
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
await self.redis.delete(channel_messages_key)
|
||||
|
||||
# 验证删除是否成功
|
||||
verify_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
verify_type = await self.redis.type(channel_messages_key)
|
||||
if verify_type != "none":
|
||||
logger.error(
|
||||
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
|
||||
)
|
||||
# 强制删除
|
||||
await self._redis_exec(self.redis.unlink, channel_messages_key)
|
||||
await self.redis.unlink(channel_messages_key)
|
||||
|
||||
except Exception as type_check_error:
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
# 如果检查失败,尝试强制删除键以确保清理
|
||||
try:
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
await self.redis.delete(channel_messages_key)
|
||||
except Exception:
|
||||
# 最后的努力:使用unlink
|
||||
try:
|
||||
await self._redis_exec(self.redis.unlink, channel_messages_key)
|
||||
await self.redis.unlink(channel_messages_key)
|
||||
except Exception as final_error:
|
||||
logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
|
||||
|
||||
# 添加到频道消息列表(sorted set)
|
||||
try:
|
||||
await self._redis_exec(
|
||||
self.redis.zadd,
|
||||
await self.redis.zadd(
|
||||
channel_messages_key,
|
||||
{f"msg:{channel_id}:{message_id}": message_id},
|
||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
||||
)
|
||||
except Exception as zadd_error:
|
||||
logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
|
||||
# 如果添加失败,再次尝试清理并重试
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
await self._redis_exec(
|
||||
self.redis.zadd,
|
||||
await self.redis.delete(channel_messages_key)
|
||||
await self.redis.zadd(
|
||||
channel_messages_key,
|
||||
{f"msg:{channel_id}:{message_id}": message_id},
|
||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
||||
)
|
||||
|
||||
# 保持频道消息列表大小(最多1000条)
|
||||
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
|
||||
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
|
||||
|
||||
# 只有非多人房间消息才添加到待持久化队列
|
||||
if not is_multiplayer:
|
||||
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
|
||||
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
|
||||
logger.debug(f"Message {message_id} added to persistence queue")
|
||||
else:
|
||||
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
||||
@@ -311,8 +301,7 @@ class RedisMessageSystem:
|
||||
# 获取消息键列表,按消息ID排序
|
||||
if since > 0:
|
||||
# 获取指定ID之后的消息(正序)
|
||||
message_keys = await self._redis_exec(
|
||||
self.redis.zrangebyscore,
|
||||
message_keys = await self.redis.zrangebyscore(
|
||||
f"channel:{channel_id}:messages",
|
||||
since + 1,
|
||||
"+inf",
|
||||
@@ -321,32 +310,24 @@ class RedisMessageSystem:
|
||||
)
|
||||
else:
|
||||
# 获取最新的消息(倒序获取,然后反转)
|
||||
message_keys = await self._redis_exec(
|
||||
self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1
|
||||
)
|
||||
message_keys = await self.redis.zrevrange(f"channel:{channel_id}:messages", 0, limit - 1)
|
||||
|
||||
messages = []
|
||||
for key in message_keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode("utf-8")
|
||||
|
||||
# 获取消息数据
|
||||
raw_data = await self._redis_exec(self.redis.hgetall, key)
|
||||
raw_data = await self.redis.hgetall(key)
|
||||
if raw_data:
|
||||
# 解码数据
|
||||
message_data = {}
|
||||
message_data: dict[str, Any] = {}
|
||||
for k, v in raw_data.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode("utf-8")
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode("utf-8")
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
|
||||
message_data[k] = json.loads(v)
|
||||
elif k in ["message_id", "channel_id", "sender_id"]:
|
||||
message_data[k] = int(v)
|
||||
elif k == "is_multiplayer":
|
||||
message_data[k] = v == "True"
|
||||
elif k == "created_at":
|
||||
message_data[k] = float(v)
|
||||
else:
|
||||
@@ -442,12 +423,10 @@ class RedisMessageSystem:
|
||||
# 获取待处理的消息
|
||||
message_keys = []
|
||||
for _ in range(self.max_batch_size):
|
||||
key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1)
|
||||
key = await self.redis.brpop("pending_messages", timeout=1)
|
||||
if key:
|
||||
# key 是 (queue_name, value) 的元组
|
||||
value = key[1]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
_, value = key
|
||||
message_keys.append(value)
|
||||
else:
|
||||
break
|
||||
@@ -472,7 +451,7 @@ class RedisMessageSystem:
|
||||
channel_id, message_id = map(int, key.split(":"))
|
||||
|
||||
# 从 Redis 获取消息数据
|
||||
raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}")
|
||||
raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
|
||||
|
||||
if not raw_data:
|
||||
continue
|
||||
@@ -480,18 +459,13 @@ class RedisMessageSystem:
|
||||
# 解码数据
|
||||
message_data = {}
|
||||
for k, v in raw_data.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode("utf-8")
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode("utf-8")
|
||||
message_data[k] = v
|
||||
|
||||
# 检查是否是多人房间消息,如果是则跳过数据库存储
|
||||
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
|
||||
if is_multiplayer:
|
||||
# 多人房间消息不存储到数据库,直接标记为已跳过
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
await self.redis.hset(
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status",
|
||||
"skipped_multiplayer",
|
||||
@@ -518,8 +492,7 @@ class RedisMessageSystem:
|
||||
session.add(db_message)
|
||||
|
||||
# 更新 Redis 中的状态
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
await self.redis.hset(
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status",
|
||||
"persisted",
|
||||
@@ -563,57 +536,54 @@ class RedisMessageSystem:
|
||||
max_id = result.one() or 0
|
||||
|
||||
# 检查 Redis 中的计数器值
|
||||
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
|
||||
current_counter = await self.redis.get("global_message_id_counter")
|
||||
current_counter = int(current_counter) if current_counter else 0
|
||||
|
||||
# 设置计数器为两者中的最大值
|
||||
initial_counter = max(max_id, current_counter)
|
||||
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
|
||||
await self.redis.set("global_message_id_counter", initial_counter)
|
||||
|
||||
logger.info(f"Initialized global message ID counter to {initial_counter}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize message counter: {e}")
|
||||
# 如果初始化失败,设置一个安全的起始值
|
||||
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
|
||||
await self.redis.setnx("global_message_id_counter", 1000000)
|
||||
|
||||
async def _cleanup_redis_keys(self):
|
||||
"""清理可能存在问题的 Redis 键"""
|
||||
try:
|
||||
# 扫描所有 channel:*:messages 键并检查类型
|
||||
keys_pattern = "channel:*:messages"
|
||||
keys = await self._redis_exec(self.redis.keys, keys_pattern)
|
||||
keys = await self.redis.keys(keys_pattern)
|
||||
|
||||
fixed_count = 0
|
||||
for key in keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode("utf-8")
|
||||
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, key)
|
||||
key_type = await self.redis.type(key)
|
||||
if key_type == "none":
|
||||
# 键不存在,正常情况
|
||||
continue
|
||||
elif key_type != "zset":
|
||||
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
# 验证删除是否成功
|
||||
verify_type = await self._redis_exec(self.redis.type, key)
|
||||
verify_type = await self.redis.type(key)
|
||||
if verify_type != "none":
|
||||
logger.error(f"Failed to delete problematic key {key}, trying unlink...")
|
||||
await self._redis_exec(self.redis.unlink, key)
|
||||
await self.redis.unlink(key)
|
||||
|
||||
fixed_count += 1
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
|
||||
# 强制删除问题键
|
||||
try:
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
await self.redis.delete(key)
|
||||
fixed_count += 1
|
||||
except Exception:
|
||||
try:
|
||||
await self._redis_exec(self.redis.unlink, key)
|
||||
await self.redis.unlink(key)
|
||||
fixed_count += 1
|
||||
except Exception as final_error:
|
||||
logger.error(f"Critical: Unable to clear problematic key {key}: {final_error}")
|
||||
@@ -654,11 +624,6 @@ class RedisMessageSystem:
|
||||
self._batch_timer = None
|
||||
logger.info("Redis message system stopped")
|
||||
|
||||
def __del__(self):
|
||||
"""清理资源"""
|
||||
if hasattr(self, "executor"):
|
||||
self.executor.shutdown(wait=False)
|
||||
|
||||
|
||||
# 全局消息系统实例
|
||||
redis_message_system = RedisMessageSystem()
|
||||
|
||||
Reference in New Issue
Block a user