diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 49c377c..3d490f2 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -9,7 +9,6 @@ from app.config import settings from fastapi import Depends from pydantic import BaseModel -import redis as sync_redis import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel @@ -41,8 +40,8 @@ redis_client = redis.from_url(settings.redis_url, decode_responses=True) # Redis 二进制数据连接 (不自动解码响应,用于存储音频等二进制数据) redis_binary_client = redis.from_url(settings.redis_url, decode_responses=False) -# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行 -redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1) +# Redis 消息缓存连接 (db1) +redis_message_client: redis.Redis = redis.from_url(settings.redis_url, decode_responses=True, db=1) # 数据库依赖 @@ -97,7 +96,7 @@ def get_redis_binary(): return redis_binary_client -def get_redis_message(): +def get_redis_message() -> redis.Redis: """获取消息专用的 Redis 客户端 (db1)""" return redis_message_client diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py index 6bf4540..34fefba 100644 --- a/app/service/redis_message_system.py +++ b/app/service/redis_message_system.py @@ -6,7 +6,6 @@ """ import asyncio -from concurrent.futures import ThreadPoolExecutor from datetime import datetime import json import time @@ -23,18 +22,12 @@ class RedisMessageSystem: """Redis 消息系统""" def __init__(self): - self.redis = get_redis_message() - self.executor = ThreadPoolExecutor(max_workers=2) + self.redis: Any = get_redis_message() self._batch_timer: asyncio.Task | None = None self._running = False self.batch_interval = 5.0 # 5秒批量存储一次 self.max_batch_size = 100 # 每批最多处理100条消息 - async def _redis_exec(self, func, *args, **kwargs): - """在线程池中执行 Redis 操作""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs)) - async def send_message( self, channel_id: int, @@ -216,10 +209,10 @@ class RedisMessageSystem: async def _generate_message_id(self, channel_id: int) -> int: """生成唯一的消息ID - 确保全局唯一且严格递增""" # 使用全局计数器确保所有频道的消息ID都是严格递增的 - message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter") + message_id = await self.redis.incr("global_message_id_counter") # 同时更新频道的最后消息ID,用于客户端状态同步 - await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id) + await self.redis.set(f"channel:{channel_id}:last_msg_id", message_id) return message_id @@ -230,73 +223,70 @@ class RedisMessageSystem: is_multiplayer = message_data.get("is_multiplayer", False) # 存储消息数据 - await self._redis_exec( - self.redis.hset, + await self.redis.hset( f"msg:{channel_id}:{message_id}", - mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()}, + mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()}, ) # 设置消息过期时间(7天) - await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800) + await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800) # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序) channel_messages_key = f"channel:{channel_id}:messages" # 更健壮的键类型检查和清理 try: - key_type = await self._redis_exec(self.redis.type, channel_messages_key) + key_type = await self.redis.type(channel_messages_key) if key_type == "none": # 键不存在,这是正常的 pass elif key_type != "zset": # 键类型错误,需要清理 logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") - await self._redis_exec(self.redis.delete, channel_messages_key) + await self.redis.delete(channel_messages_key) # 验证删除是否成功 - verify_type = await self._redis_exec(self.redis.type, channel_messages_key) + verify_type = await self.redis.type(channel_messages_key) if verify_type != "none": logger.error( f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}" ) # 强制删除 - await self._redis_exec(self.redis.unlink, channel_messages_key) + await self.redis.unlink(channel_messages_key) except Exception as type_check_error: logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") # 如果检查失败,尝试强制删除键以确保清理 try: - await self._redis_exec(self.redis.delete, channel_messages_key) + await self.redis.delete(channel_messages_key) except Exception: # 最后的努力:使用unlink try: - await self._redis_exec(self.redis.unlink, channel_messages_key) + await self.redis.unlink(channel_messages_key) except Exception as final_error: logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}") # 添加到频道消息列表(sorted set) try: - await self._redis_exec( - self.redis.zadd, + await self.redis.zadd( channel_messages_key, - {f"msg:{channel_id}:{message_id}": message_id}, + mapping={f"msg:{channel_id}:{message_id}": message_id}, ) except Exception as zadd_error: logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}") # 如果添加失败,再次尝试清理并重试 - await self._redis_exec(self.redis.delete, channel_messages_key) - await self._redis_exec( - self.redis.zadd, + await self.redis.delete(channel_messages_key) + await self.redis.zadd( channel_messages_key, - {f"msg:{channel_id}:{message_id}": message_id}, + mapping={f"msg:{channel_id}:{message_id}": message_id}, ) # 保持频道消息列表大小(最多1000条) - await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001) + await self.redis.zremrangebyrank(channel_messages_key, 0, -1001) # 只有非多人房间消息才添加到待持久化队列 if not is_multiplayer: - await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}") + await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}") logger.debug(f"Message {message_id} added to persistence queue") else: logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue") @@ -311,8 +301,7 @@ class RedisMessageSystem: # 获取消息键列表,按消息ID排序 if since > 0: # 获取指定ID之后的消息(正序) - message_keys = await self._redis_exec( - self.redis.zrangebyscore, + message_keys = await self.redis.zrangebyscore( f"channel:{channel_id}:messages", since + 1, "+inf", @@ -321,32 +310,24 @@ class RedisMessageSystem: ) else: # 获取最新的消息(倒序获取,然后反转) - message_keys = await self._redis_exec( - self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1 - ) + message_keys = await self.redis.zrevrange(f"channel:{channel_id}:messages", 0, limit - 1) messages = [] for key in message_keys: - if isinstance(key, bytes): - key = key.decode("utf-8") - # 获取消息数据 - raw_data = await self._redis_exec(self.redis.hgetall, key) + raw_data = await self.redis.hgetall(key) if raw_data: # 解码数据 - message_data = {} + message_data: dict[str, Any] = {} for k, v in raw_data.items(): - if isinstance(k, bytes): - k = k.decode("utf-8") - if isinstance(v, bytes): - v = v.decode("utf-8") - # 尝试解析 JSON try: if k in ["grade_counts", "level"] or v.startswith(("{", "[")): message_data[k] = json.loads(v) elif k in ["message_id", "channel_id", "sender_id"]: message_data[k] = int(v) + elif k == "is_multiplayer": + message_data[k] = v == "True" elif k == "created_at": message_data[k] = float(v) else: @@ -442,12 +423,10 @@ class RedisMessageSystem: # 获取待处理的消息 message_keys = [] for _ in range(self.max_batch_size): - key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1) + key = await self.redis.brpop("pending_messages", timeout=1) if key: # key 是 (queue_name, value) 的元组 - value = key[1] - if isinstance(value, bytes): - value = value.decode("utf-8") + _, value = key message_keys.append(value) else: break @@ -472,7 +451,7 @@ class RedisMessageSystem: channel_id, message_id = map(int, key.split(":")) # 从 Redis 获取消息数据 - raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}") + raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}") if not raw_data: continue @@ -480,18 +459,13 @@ class RedisMessageSystem: # 解码数据 message_data = {} for k, v in raw_data.items(): - if isinstance(k, bytes): - k = k.decode("utf-8") - if isinstance(v, bytes): - v = v.decode("utf-8") message_data[k] = v # 检查是否是多人房间消息,如果是则跳过数据库存储 is_multiplayer = message_data.get("is_multiplayer", "False") == "True" if is_multiplayer: # 多人房间消息不存储到数据库,直接标记为已跳过 - await self._redis_exec( - self.redis.hset, + await self.redis.hset( f"msg:{channel_id}:{message_id}", "status", "skipped_multiplayer", @@ -518,8 +492,7 @@ class RedisMessageSystem: session.add(db_message) # 更新 Redis 中的状态 - await self._redis_exec( - self.redis.hset, + await self.redis.hset( f"msg:{channel_id}:{message_id}", "status", "persisted", @@ -563,57 +536,54 @@ class RedisMessageSystem: max_id = result.one() or 0 # 检查 Redis 中的计数器值 - current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter") + current_counter = await self.redis.get("global_message_id_counter") current_counter = int(current_counter) if current_counter else 0 # 设置计数器为两者中的最大值 initial_counter = max(max_id, current_counter) - await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter) + await self.redis.set("global_message_id_counter", initial_counter) logger.info(f"Initialized global message ID counter to {initial_counter}") except Exception as e: logger.error(f"Failed to initialize message counter: {e}") # 如果初始化失败,设置一个安全的起始值 - await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000) + await self.redis.setnx("global_message_id_counter", 1000000) async def _cleanup_redis_keys(self): """清理可能存在问题的 Redis 键""" try: # 扫描所有 channel:*:messages 键并检查类型 keys_pattern = "channel:*:messages" - keys = await self._redis_exec(self.redis.keys, keys_pattern) + keys = await self.redis.keys(keys_pattern) fixed_count = 0 for key in keys: - if isinstance(key, bytes): - key = key.decode("utf-8") - try: - key_type = await self._redis_exec(self.redis.type, key) + key_type = await self.redis.type(key) if key_type == "none": # 键不存在,正常情况 continue elif key_type != "zset": logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}") - await self._redis_exec(self.redis.delete, key) + await self.redis.delete(key) # 验证删除是否成功 - verify_type = await self._redis_exec(self.redis.type, key) + verify_type = await self.redis.type(key) if verify_type != "none": logger.error(f"Failed to delete problematic key {key}, trying unlink...") - await self._redis_exec(self.redis.unlink, key) + await self.redis.unlink(key) fixed_count += 1 except Exception as cleanup_error: logger.warning(f"Failed to cleanup key {key}: {cleanup_error}") # 强制删除问题键 try: - await self._redis_exec(self.redis.delete, key) + await self.redis.delete(key) fixed_count += 1 except Exception: try: - await self._redis_exec(self.redis.unlink, key) + await self.redis.unlink(key) fixed_count += 1 except Exception as final_error: logger.error(f"Critical: Unable to clear problematic key {key}: {final_error}") @@ -654,11 +624,6 @@ class RedisMessageSystem: self._batch_timer = None logger.info("Redis message system stopped") - def __del__(self): - """清理资源""" - if hasattr(self, "executor"): - self.executor.shutdown(wait=False) - # 全局消息系统实例 redis_message_system = RedisMessageSystem()