""" 基于 Redis 的实时消息系统 - 消息立即存储到 Redis 并实时返回 - 定时批量存储到数据库 - 支持消息状态同步和故障恢复 """ from __future__ import annotations import asyncio from concurrent.futures import ThreadPoolExecutor from datetime import datetime import json import time from typing import Any from app.database.chat import ChatMessage, ChatMessageResp, MessageType from app.database.lazer_user import RANKING_INCLUDES, User, UserResp from app.dependencies.database import get_redis_message, with_db from app.log import logger from app.utils import bg_tasks class RedisMessageSystem: """Redis 消息系统""" def __init__(self): self.redis = get_redis_message() self.executor = ThreadPoolExecutor(max_workers=2) self._batch_timer: asyncio.Task | None = None self._running = False self.batch_interval = 5.0 # 5秒批量存储一次 self.max_batch_size = 100 # 每批最多处理100条消息 async def _redis_exec(self, func, *args, **kwargs): """在线程池中执行 Redis 操作""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs)) async def send_message( self, channel_id: int, user: User, content: str, is_action: bool = False, user_uuid: str | None = None, ) -> ChatMessageResp: """ 发送消息 - 立即存储到 Redis 并返回 Args: channel_id: 频道ID user: 发送用户 content: 消息内容 is_action: 是否为动作消息 user_uuid: 用户UUID Returns: ChatMessageResp: 消息响应对象 """ # 生成消息ID和时间戳 message_id = await self._generate_message_id(channel_id) timestamp = datetime.now() # 确保用户ID存在 if not user.id: raise ValueError("User ID is required") # 获取频道类型以判断是否需要存储到数据库 async with with_db() as session: from app.database.chat import ChannelType, ChatChannel from sqlmodel import select channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)) channel_type = channel_result.first() is_multiplayer = channel_type == ChannelType.MULTIPLAYER # 准备消息数据 message_data = { "message_id": message_id, "channel_id": channel_id, "sender_id": user.id, "content": content, "timestamp": timestamp.isoformat(), "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value, "uuid": user_uuid or "", "status": "cached", # Redis 缓存状态 "created_at": time.time(), "is_multiplayer": is_multiplayer, # 标记是否为多人房间消息 } # 立即存储到 Redis await self._store_to_redis(message_id, channel_id, message_data) # 创建响应对象 async with with_db() as session: user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES) # 确保 statistics 不为空 if user_resp.statistics is None: from app.database.statistics import UserStatisticsResp user_resp.statistics = UserStatisticsResp( mode=user.playmode, global_rank=0, country_rank=0, pp=0.0, ranked_score=0, hit_accuracy=0.0, play_count=0, play_time=0, total_score=0, total_hits=0, maximum_combo=0, replays_watched_by_others=0, is_ranked=False, grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0}, level={"current": 1, "progress": 0}, ) response = ChatMessageResp( message_id=message_id, channel_id=channel_id, content=content, timestamp=timestamp, sender_id=user.id, sender=user_resp, is_action=is_action, uuid=user_uuid, ) if is_multiplayer: logger.info( f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}," " will not be persisted to database" ) else: logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}") return response async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]: """ 获取频道消息 - 优先从 Redis 获取最新消息 Args: channel_id: 频道ID limit: 消息数量限制 since: 起始消息ID Returns: List[ChatMessageResp]: 消息列表 """ messages = [] try: # 从 Redis 获取最新消息 redis_messages = await self._get_from_redis(channel_id, limit, since) # 为每条消息构建响应对象 async with with_db() as session: for msg_data in redis_messages: # 获取发送者信息 sender = await session.get(User, msg_data["sender_id"]) if sender: user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES) if user_resp.statistics is None: from app.database.statistics import UserStatisticsResp user_resp.statistics = UserStatisticsResp( mode=sender.playmode, global_rank=0, country_rank=0, pp=0.0, ranked_score=0, hit_accuracy=0.0, play_count=0, play_time=0, total_score=0, total_hits=0, maximum_combo=0, replays_watched_by_others=0, is_ranked=False, grade_counts={ "ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0, }, level={"current": 1, "progress": 0}, ) message_resp = ChatMessageResp( message_id=msg_data["message_id"], channel_id=msg_data["channel_id"], content=msg_data["content"], timestamp=datetime.fromisoformat(msg_data["timestamp"]), sender_id=msg_data["sender_id"], sender=user_resp, is_action=msg_data["type"] == MessageType.ACTION.value, uuid=msg_data.get("uuid") or None, ) messages.append(message_resp) # 如果 Redis 消息不够,从数据库补充 if len(messages) < limit and since == 0: await self._backfill_from_database(channel_id, messages, limit) except Exception as e: logger.error(f"Failed to get messages from Redis: {e}") # 回退到数据库查询 messages = await self._get_from_database_only(channel_id, limit, since) return messages[:limit] async def _generate_message_id(self, channel_id: int) -> int: """生成唯一的消息ID - 确保全局唯一且严格递增""" # 使用全局计数器确保所有频道的消息ID都是严格递增的 message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter") # 同时更新频道的最后消息ID,用于客户端状态同步 await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id) return message_id async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]): """存储消息到 Redis""" try: # 检查是否是多人房间消息 is_multiplayer = message_data.get("is_multiplayer", False) # 存储消息数据 await self._redis_exec( self.redis.hset, f"msg:{channel_id}:{message_id}", mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()}, ) # 设置消息过期时间(7天) await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800) # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序) channel_messages_key = f"channel:{channel_id}:messages" # 检查键的类型,如果不是 zset 类型则删除 try: key_type = await self._redis_exec(self.redis.type, channel_messages_key) if key_type and key_type != "zset": logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") await self._redis_exec(self.redis.delete, channel_messages_key) except Exception as type_check_error: logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") # 如果检查失败,直接删除键以确保清理 await self._redis_exec(self.redis.delete, channel_messages_key) # 添加到频道消息列表(sorted set) await self._redis_exec( self.redis.zadd, channel_messages_key, {f"msg:{channel_id}:{message_id}": message_id}, ) # 保持频道消息列表大小(最多1000条) await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001) # 只有非多人房间消息才添加到待持久化队列 if not is_multiplayer: await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}") logger.debug(f"Message {message_id} added to persistence queue") else: logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue") except Exception as e: logger.error(f"Failed to store message to Redis: {e}") raise async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]: """从 Redis 获取消息""" try: # 获取消息键列表,按消息ID排序 if since > 0: # 获取指定ID之后的消息(正序) message_keys = await self._redis_exec( self.redis.zrangebyscore, f"channel:{channel_id}:messages", since + 1, "+inf", start=0, num=limit, ) else: # 获取最新的消息(倒序获取,然后反转) message_keys = await self._redis_exec( self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1 ) messages = [] for key in message_keys: if isinstance(key, bytes): key = key.decode("utf-8") # 获取消息数据 raw_data = await self._redis_exec(self.redis.hgetall, key) if raw_data: # 解码数据 message_data = {} for k, v in raw_data.items(): if isinstance(k, bytes): k = k.decode("utf-8") if isinstance(v, bytes): v = v.decode("utf-8") # 尝试解析 JSON try: if k in ["grade_counts", "level"] or v.startswith(("{", "[")): message_data[k] = json.loads(v) elif k in ["message_id", "channel_id", "sender_id"]: message_data[k] = int(v) elif k == "created_at": message_data[k] = float(v) else: message_data[k] = v except (json.JSONDecodeError, ValueError): message_data[k] = v messages.append(message_data) # 确保消息按ID正序排序(时间顺序) messages.sort(key=lambda x: x.get("message_id", 0)) # 如果是获取最新消息(since=0),需要保持倒序(最新的在前面) if since == 0: messages.reverse() return messages except Exception as e: logger.error(f"Failed to get messages from Redis: {e}") return [] async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int): """从数据库补充历史消息""" try: # 找到最小的消息ID min_id = float("inf") if existing_messages: for msg in existing_messages: if msg.message_id is not None and msg.message_id < min_id: min_id = msg.message_id needed = limit - len(existing_messages) if needed <= 0: return async with with_db() as session: from sqlmodel import col, select query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) if min_id != float("inf"): query = query.where(col(ChatMessage.message_id) < min_id) query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed) db_messages = (await session.exec(query)).all() for msg in reversed(db_messages): # 按时间正序插入 msg_resp = await ChatMessageResp.from_db(msg, session) existing_messages.insert(0, msg_resp) except Exception as e: logger.error(f"Failed to backfill from database: {e}") async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]: """仅从数据库获取消息(回退方案)""" try: async with with_db() as session: from sqlmodel import col, select query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) if since > 0: # 获取指定ID之后的消息,按ID正序 query = query.where(col(ChatMessage.message_id) > since) query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit) else: # 获取最新消息,按ID倒序(最新的在前面) query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) messages = (await session.exec(query)).all() results = [await ChatMessageResp.from_db(msg, session) for msg in messages] # 如果是 since > 0,保持正序;否则反转为时间正序 if since == 0: results.reverse() return results except Exception as e: logger.error(f"Failed to get messages from database: {e}") return [] async def _batch_persist_to_database(self): """批量持久化消息到数据库""" logger.info("Starting batch persistence to database") while self._running: try: # 获取待处理的消息 message_keys = [] for _ in range(self.max_batch_size): key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1) if key: # key 是 (queue_name, value) 的元组 value = key[1] if isinstance(value, bytes): value = value.decode("utf-8") message_keys.append(value) else: break if message_keys: await self._process_message_batch(message_keys) else: await asyncio.sleep(self.batch_interval) except Exception as e: logger.error(f"Error in batch persistence: {e}") await asyncio.sleep(1) logger.info("Stopped batch persistence to database") async def _process_message_batch(self, message_keys: list[str]): """处理消息批次""" async with with_db() as session: for key in message_keys: try: # 解析频道ID和消息ID channel_id, message_id = map(int, key.split(":")) # 从 Redis 获取消息数据 raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}") if not raw_data: continue # 解码数据 message_data = {} for k, v in raw_data.items(): if isinstance(k, bytes): k = k.decode("utf-8") if isinstance(v, bytes): v = v.decode("utf-8") message_data[k] = v # 检查是否是多人房间消息,如果是则跳过数据库存储 is_multiplayer = message_data.get("is_multiplayer", "False") == "True" if is_multiplayer: # 多人房间消息不存储到数据库,直接标记为已跳过 await self._redis_exec( self.redis.hset, f"msg:{channel_id}:{message_id}", "status", "skipped_multiplayer", ) logger.debug(f"Message {message_id} in multiplayer room skipped from database storage") continue # 检查消息是否已存在于数据库 existing = await session.get(ChatMessage, int(message_id)) if existing: continue # 创建数据库消息 - 使用 Redis 生成的正数ID db_message = ChatMessage( message_id=int(message_id), # 使用 Redis 系统生成的正数ID channel_id=int(message_data["channel_id"]), sender_id=int(message_data["sender_id"]), content=message_data["content"], timestamp=datetime.fromisoformat(message_data["timestamp"]), type=MessageType(message_data["type"]), uuid=message_data.get("uuid") or None, ) session.add(db_message) # 更新 Redis 中的状态 await self._redis_exec( self.redis.hset, f"msg:{channel_id}:{message_id}", "status", "persisted", ) logger.debug(f"Message {message_id} persisted to database") except Exception as e: logger.error(f"Failed to process message {key}: {e}") # 提交批次 try: await session.commit() logger.info(f"Batch of {len(message_keys)} messages committed to database") except Exception as e: logger.error(f"Failed to commit message batch: {e}") await session.rollback() def start(self): """启动系统""" if not self._running: self._running = True self._batch_timer = asyncio.create_task(self._batch_persist_to_database()) # 启动时初始化消息ID计数器 bg_tasks.add_task(self._initialize_message_counter) logger.info("Redis message system started") async def _initialize_message_counter(self): """初始化全局消息ID计数器,确保从数据库最大ID开始""" try: # 清理可能存在的问题键 await self._cleanup_redis_keys() async with with_db() as session: from sqlmodel import func, select # 获取数据库中最大的消息ID result = await session.exec(select(func.max(ChatMessage.message_id))) max_id = result.one() or 0 # 检查 Redis 中的计数器值 current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter") current_counter = int(current_counter) if current_counter else 0 # 设置计数器为两者中的最大值 initial_counter = max(max_id, current_counter) await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter) logger.info(f"Initialized global message ID counter to {initial_counter}") except Exception as e: logger.error(f"Failed to initialize message counter: {e}") # 如果初始化失败,设置一个安全的起始值 await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000) async def _cleanup_redis_keys(self): """清理可能存在问题的 Redis 键""" try: # 扫描所有 channel:*:messages 键并检查类型 keys_pattern = "channel:*:messages" keys = await self._redis_exec(self.redis.keys, keys_pattern) for key in keys: if isinstance(key, bytes): key = key.decode("utf-8") try: key_type = await self._redis_exec(self.redis.type, key) if key_type and key_type != "zset": logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}") await self._redis_exec(self.redis.delete, key) except Exception as cleanup_error: logger.warning(f"Failed to cleanup key {key}: {cleanup_error}") # 强制删除问题键 await self._redis_exec(self.redis.delete, key) logger.info("Redis keys cleanup completed") except Exception as e: logger.error(f"Failed to cleanup Redis keys: {e}") def stop(self): """停止系统""" if self._running: self._running = False if self._batch_timer: self._batch_timer.cancel() self._batch_timer = None logger.info("Redis message system stopped") def __del__(self): """清理资源""" if hasattr(self, "executor"): self.executor.shutdown(wait=False) # 全局消息系统实例 redis_message_system = RedisMessageSystem()