From 1fe603f41698ae275a5528ce380457921693b808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com> Date: Fri, 22 Aug 2025 01:49:03 +0800 Subject: [PATCH] add message redis --- app/database/chat.py | 14 +- app/dependencies/database.py | 9 + app/models/multiplayer_hub.py | 6 +- app/router/notification/channel.py | 96 ++++- app/router/notification/message.py | 160 ++++++-- app/router/notification/server.py | 73 +++- app/service/message_queue.py | 217 ++++++++++ app/service/message_queue_processor.py | 282 +++++++++++++ app/service/optimized_message.py | 150 +++++++ app/service/redis_message_system.py | 537 +++++++++++++++++++++++++ main.py | 3 + 11 files changed, 1461 insertions(+), 86 deletions(-) create mode 100644 app/service/message_queue.py create mode 100644 app/service/message_queue_processor.py create mode 100644 app/service/optimized_message.py create mode 100644 app/service/redis_message_system.py diff --git a/app/database/chat.py b/app/database/chat.py index 9647b48..8ac1f82 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -59,12 +59,17 @@ class ChatChannel(ChatChannelBase, table=True): cls, channel: str | int, session: AsyncSession ) -> "ChatChannel | None": if isinstance(channel, int) or channel.isdigit(): - channel_ = await session.get(ChatChannel, channel) + # 使用查询而不是 get() 来确保对象完全加载 + result = await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + channel_ = result.first() if channel_ is not None: return channel_ - return ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + result = await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + return result.first() @classmethod async def get_pm_channel( @@ -235,6 +240,7 @@ class UserSilenceResp(SQLModel): @classmethod def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp": + assert db_silence.id is not None return cls( id=db_silence.id, user_id=db_silence.user_id, diff --git a/app/dependencies/database.py b/app/dependencies/database.py index f345537..4ffefd0 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -11,6 +11,7 @@ from app.config import settings from fastapi import Depends from pydantic import BaseModel import redis.asyncio as redis +import redis as sync_redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -38,6 +39,9 @@ engine = create_async_engine( # Redis 连接 redis_client = redis.from_url(settings.redis_url, decode_responses=True) +# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行 +redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1) + # 数据库依赖 db_session_context: ContextVar[AsyncSession | None] = ContextVar( @@ -80,5 +84,10 @@ def get_redis(): return redis_client +def get_redis_message(): + """获取消息专用的 Redis 客户端 (db1)""" + return redis_message_client + + def get_redis_pubsub(): return redis_client.pubsub() diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index e583a69..a8cd613 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -151,7 +151,6 @@ class PlaylistItem(BaseModel): def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: from typing import Literal, cast - API_MODS = self._get_api_mods() typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) for i, mod1 in enumerate(mods): @@ -168,7 +167,6 @@ class PlaylistItem(BaseModel): def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: from typing import Literal, cast - API_MODS = self._get_api_mods() typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods} @@ -213,8 +211,6 @@ class PlaylistItem(BaseModel): """ from typing import Literal, cast - API_MODS = self._get_api_mods() - ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id) @@ -386,7 +382,7 @@ class MultiplayerRoom(BaseModel): match_state=None, playlist=playlist, active_countdowns=[], - channel_id=room.channel_id, + channel_id=room.channel_id or 0, ) diff --git a/app/router/notification/channel.py b/app/router/notification/channel.py index fbba4ae..9c5e941 100644 --- a/app/router/notification/channel.py +++ b/app/router/notification/channel.py @@ -53,16 +53,24 @@ async def get_update( assert current_user.id channel_ids = server.get_user_joined_channel(current_user.id) for channel_id in channel_ids: - channel = await ChatChannel.get(channel_id, session) - if channel: + # 使用明确的查询避免延迟加载 + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == channel_id) + ) + ).first() + if db_channel: + # 提取必要的属性避免惰性加载 + channel_type = db_channel.type + resp.presence.append( await ChatChannelResp.from_db( - channel, + db_channel, session, current_user, redis, server.channels.get(channel_id, []) - if channel.type != ChannelType.PUBLIC + if channel_type != ChannelType.PUBLIC else None, ) ) @@ -105,7 +113,19 @@ async def join_channel( user: str = Path(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), ): - db_channel = await ChatChannel.get(channel, session) + # 使用明确的查询避免延迟加载 + if channel.isdigit(): + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + ).first() + else: + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + ).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -125,7 +145,19 @@ async def leave_channel( user: str = Path(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["chat.write_manage"]), ): - db_channel = await ChatChannel.get(channel, session) + # 使用明确的查询避免延迟加载 + if channel.isdigit(): + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + ).first() + else: + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + ).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -152,15 +184,19 @@ async def get_channel_list( ).all() results = [] for channel in channels: - assert channel.channel_id is not None + # 提取必要的属性避免惰性加载 + channel_id = channel.channel_id + channel_type = channel.type + + assert channel_id is not None results.append( await ChatChannelResp.from_db( channel, session, current_user, redis, - server.channels.get(channel.channel_id, []) - if channel.type != ChannelType.PUBLIC + server.channels.get(channel_id, []) + if channel_type != ChannelType.PUBLIC else None, ) ) @@ -185,14 +221,33 @@ async def get_channel( current_user: User = Security(get_current_user, scopes=["chat.read"]), redis: Redis = Depends(get_redis), ): - db_channel = await ChatChannel.get(channel, session) + # 使用明确的查询避免延迟加载 + if channel.isdigit(): + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + ).first() + else: + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + ).first() + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - assert db_channel.channel_id is not None + + # 立即提取需要的属性 + channel_id = db_channel.channel_id + channel_type = db_channel.type + channel_name = db_channel.name + + assert channel_id is not None users = [] - if db_channel.type == ChannelType.PM: - user_ids = db_channel.name.split("_")[1:] + if channel_type == ChannelType.PM: + user_ids = channel_name.split("_")[1:] if len(user_ids) != 2: raise HTTPException(status_code=404, detail="Target user not found") for id_ in user_ids: @@ -210,8 +265,8 @@ async def get_channel( session, current_user, redis, - server.channels.get(db_channel.channel_id, []) - if db_channel.type != ChannelType.PUBLIC + server.channels.get(channel_id, []) + if channel_type != ChannelType.PUBLIC else None, ) ) @@ -270,7 +325,8 @@ async def create_channel( channel_name = f"pm_{current_user.id}_{req.target_id}" else: channel_name = req.channel.name if req.channel else "Unnamed Channel" - channel = await ChatChannel.get(channel_name, session) + result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name)) + channel = result.first() if channel is None: channel = ChatChannel( @@ -294,12 +350,16 @@ async def create_channel( await server.batch_join_channel([*target_users, current_user], channel, session) await server.join_channel(current_user, channel, session) - assert channel.channel_id + + # 提取必要的属性避免惰性加载 + channel_id = channel.channel_id + assert channel_id + return await ChatChannelResp.from_db( channel, session, current_user, redis, - server.channels.get(channel.channel_id, []), + server.channels.get(channel_id, []), include_recent_messages=True, ) diff --git a/app/router/notification/message.py b/app/router/notification/message.py index 44b9f6e..ec3b419 100644 --- a/app/router/notification/message.py +++ b/app/router/notification/message.py @@ -1,5 +1,10 @@ from __future__ import annotations +import json +import uuid +from datetime import datetime +from typing import Optional + from app.database import ChatMessageResp from app.database.chat import ( ChannelType, @@ -11,11 +16,14 @@ from app.database.chat import ( UserSilenceResp, ) from app.database.lazer_user import User -from app.dependencies.database import Database, get_redis +from app.dependencies.database import Database, get_redis, get_redis_message from app.dependencies.param import BodyOrForm from app.dependencies.user import get_current_user from app.models.notification import ChannelMessage, ChannelMessageTeam from app.router.v2 import api_v2_router as router +from app.service.optimized_message import optimized_message_service +from app.service.redis_message_system import redis_message_system +from app.log import logger from .banchobot import bot from .server import server @@ -89,42 +97,73 @@ async def send_message( req: MessageReq = Depends(BodyOrForm(MessageReq)), current_user: User = Security(get_current_user, scopes=["chat.write"]), ): - db_channel = await ChatChannel.get(channel, session) + # 使用明确的查询来获取 channel,避免延迟加载 + if channel.isdigit(): + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + ).first() + else: + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + ).first() + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - assert db_channel.channel_id + # 立即提取所有需要的属性,避免后续延迟加载 + channel_id = db_channel.channel_id + channel_type = db_channel.type + channel_name = db_channel.name + + assert channel_id is not None assert current_user.id - msg = ChatMessage( - channel_id=db_channel.channel_id, + + # 使用 Redis 消息系统发送消息 - 立即返回 + resp = await redis_message_system.send_message( + channel_id=channel_id, + user=current_user, content=req.message, - sender_id=current_user.id, - type=MessageType.ACTION if req.is_action else MessageType.PLAIN, - uuid=req.uuid, + is_action=req.is_action, + user_uuid=req.uuid ) - session.add(msg) - await session.commit() - await session.refresh(msg) - await session.refresh(current_user) - await session.refresh(db_channel) - resp = await ChatMessageResp.from_db(msg, session, current_user) + + # 立即广播消息给所有客户端 is_bot_command = req.message.startswith("!") await server.send_message_to_channel( - resp, is_bot_command and db_channel.type == ChannelType.PUBLIC + resp, is_bot_command and channel_type == ChannelType.PUBLIC ) + + # 处理机器人命令 if is_bot_command: await bot.try_handle(current_user, db_channel, req.message, session) - if db_channel.type == ChannelType.PM: - user_ids = db_channel.name.split("_")[1:] - await server.new_private_notification( - ChannelMessage.init( - msg, current_user, [int(u) for u in user_ids], db_channel.type + + # 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道) + if channel_type in [ChannelType.PM, ChannelType.TEAM]: + temp_msg = ChatMessage( + message_id=resp.message_id, # 使用 Redis 系统生成的ID + channel_id=channel_id, + content=req.message, + sender_id=current_user.id, + type=MessageType.ACTION if req.is_action else MessageType.PLAIN, + uuid=req.uuid, + ) + + if channel_type == ChannelType.PM: + user_ids = channel_name.split("_")[1:] + await server.new_private_notification( + ChannelMessage.init( + temp_msg, current_user, [int(u) for u in user_ids], channel_type + ) ) - ) - elif db_channel.type == ChannelType.TEAM: - await server.new_private_notification( - ChannelMessageTeam.init(msg, current_user) - ) + elif channel_type == ChannelType.TEAM: + await server.new_private_notification( + ChannelMessageTeam.init(temp_msg, current_user) + ) + return resp @@ -143,21 +182,46 @@ async def get_message( until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"), current_user: User = Security(get_current_user, scopes=["chat.read"]), ): - db_channel = await ChatChannel.get(channel, session) + # 使用明确的查询获取 channel,避免延迟加载 + if channel.isdigit(): + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + ).first() + else: + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + ).first() + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - messages = await session.exec( - select(ChatMessage) - .where( - ChatMessage.channel_id == db_channel.channel_id, - col(ChatMessage.message_id) > since, - col(ChatMessage.message_id) < until if until is not None else True, - ) - .order_by(col(ChatMessage.timestamp).desc()) - .limit(limit) - ) + + # 提取必要的属性避免惰性加载 + channel_id = db_channel.channel_id + assert channel_id is not None + + # 使用 Redis 消息系统获取消息 + try: + messages = await redis_message_system.get_messages(channel_id, limit, since) + return messages + except Exception as e: + logger.warning(f"Failed to get messages from Redis system: {e}") + # 回退到传统数据库查询 + pass + + # 回退到数据库查询 + query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) + if since > 0: + query = query.where(col(ChatMessage.message_id) > since) + if until is not None: + query = query.where(col(ChatMessage.message_id) < until) + + query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) + messages = (await session.exec(query)).all() resp = [await ChatMessageResp.from_db(msg, session) for msg in messages] - resp.reverse() return resp @@ -174,12 +238,28 @@ async def mark_as_read( message: int = Path(..., description="消息 ID"), current_user: User = Security(get_current_user, scopes=["chat.read"]), ): - db_channel = await ChatChannel.get(channel, session) + # 使用明确的查询获取 channel,避免延迟加载 + if channel.isdigit(): + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == int(channel)) + ) + ).first() + else: + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.name == channel) + ) + ).first() + if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") - assert db_channel.channel_id + + # 立即提取需要的属性 + channel_id = db_channel.channel_id + assert channel_id assert current_user.id - await server.mark_as_read(db_channel.channel_id, current_user.id, message) + await server.mark_as_read(channel_id, current_user.id, message) class PMReq(BaseModel): diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 0f891c3..e24106f 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -59,9 +59,14 @@ class ChatServer: for channel_id, channel in self.channels.items(): if user_id in channel: channel.remove(user_id) - channel = await ChatChannel.get(channel_id, session) - if channel: - await self.leave_channel(user, channel, session) + # 使用明确的查询避免延迟加载 + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == channel_id) + ) + ).first() + if db_channel: + await self.leave_channel(user, db_channel, session) @overload async def send_event(self, client: int, event: ChatEvent): ... @@ -79,8 +84,11 @@ class ChatServer: await client.send_text(event.model_dump_json()) async def broadcast(self, channel_id: int, event: ChatEvent): - for user_id in self.channels.get(channel_id, []): + users_in_channel = self.channels.get(channel_id, []) + logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}") + for user_id in users_in_channel: await self.send_event(user_id, event) + logger.debug(f"Sent event to user {user_id} in channel {channel_id}") async def mark_as_read(self, channel_id: int, user_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id) @@ -88,24 +96,35 @@ class ChatServer: async def send_message_to_channel( self, message: ChatMessageResp, is_bot_command: bool = False ): + logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}") + event = ChatEvent( event="chat.message.new", data={"messages": [message], "users": [message.sender]}, ) if is_bot_command: + logger.info(f"Sending bot command to user {message.sender_id}") self._add_task(self.send_event(message.sender_id, event)) else: + # 总是广播消息,无论是临时ID还是真实ID + logger.info(f"Broadcasting message to all users in channel {message.channel_id}") self._add_task( self.broadcast( message.channel_id, event, ) ) - assert message.message_id - await self.mark_as_read( - message.channel_id, message.sender_id, message.message_id - ) - await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id) + + # 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息 + # Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理 + if message.message_id and message.message_id > 0: + await self.mark_as_read( + message.channel_id, message.sender_id, message.message_id + ) + await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id) + logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}") + else: + logger.debug(f"Skipping last message update for message ID: {message.message_id}") async def batch_join_channel( self, users: list[User], channel: ChatChannel, session: AsyncSession @@ -206,27 +225,37 @@ class ChatServer: async def join_room_channel(self, channel_id: int, user_id: int): async with with_db() as session: - channel = await ChatChannel.get(channel_id, session) - if channel is None: + # 使用明确的查询避免延迟加载 + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == channel_id) + ) + ).first() + if db_channel is None: return user = await session.get(User, user_id) if user is None: return - await self.join_channel(user, channel, session) + await self.join_channel(user, db_channel, session) async def leave_room_channel(self, channel_id: int, user_id: int): async with with_db() as session: - channel = await ChatChannel.get(channel_id, session) - if channel is None: + # 使用明确的查询避免延迟加载 + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == channel_id) + ) + ).first() + if db_channel is None: return user = await session.get(User, user_id) if user is None: return - await self.leave_channel(user, channel, session) + await self.leave_channel(user, db_channel, session) async def new_private_notification(self, detail: NotificationDetail): async with with_db() as session: @@ -309,7 +338,13 @@ async def chat_websocket( user_id = user.id assert user_id server.connect(user_id, websocket) - channel = await ChatChannel.get(1, session) - if channel is not None: - await server.join_channel(user, channel, session) - await _listen_stop(websocket, user_id, factory) + # 使用明确的查询避免延迟加载 + db_channel = ( + await session.exec( + select(ChatChannel).where(ChatChannel.channel_id == 1) + ) + ).first() + if db_channel is not None: + await server.join_channel(user, db_channel, session) + + await _listen_stop(websocket, user_id, factory) diff --git a/app/service/message_queue.py b/app/service/message_queue.py new file mode 100644 index 0000000..cc921d5 --- /dev/null +++ b/app/service/message_queue.py @@ -0,0 +1,217 @@ +""" +Redis 消息队列服务 +用于实现实时消息推送和异步数据库持久化 +""" + +import asyncio +import json +import uuid +from datetime import datetime +from functools import partial +from typing import Optional, Union +import concurrent.futures + +from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType +from app.dependencies.database import get_redis, with_db +from app.log import logger + + +class MessageQueue: + """Redis 消息队列服务""" + + def __init__(self): + self.redis = get_redis() + self._processing = False + self._batch_size = 50 # 批量处理大小 + self._batch_timeout = 1.0 # 批量处理超时时间(秒) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + + async def _run_in_executor(self, func, *args): + """在线程池中运行同步 Redis 操作""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, func, *args) + + async def start_processing(self): + """启动消息处理任务""" + if not self._processing: + self._processing = True + asyncio.create_task(self._process_message_queue()) + logger.info("Message queue processing started") + + async def stop_processing(self): + """停止消息处理""" + self._processing = False + logger.info("Message queue processing stopped") + + async def enqueue_message(self, message_data: dict) -> str: + """ + 将消息加入 Redis 队列(实时响应) + + Args: + message_data: 消息数据字典,包含所有必要的字段 + + Returns: + 消息的临时 UUID + """ + # 生成临时 UUID + temp_uuid = str(uuid.uuid4()) + message_data["temp_uuid"] = temp_uuid + message_data["timestamp"] = datetime.now().isoformat() + message_data["status"] = "pending" # pending, processing, completed, failed + + # 将消息存储到 Redis + await self._run_in_executor( + lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data) + ) + await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期 + + # 加入处理队列 + await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid) + + logger.info(f"Message enqueued with temp_uuid: {temp_uuid}") + return temp_uuid + + async def get_message_status(self, temp_uuid: str) -> Optional[dict]: + """获取消息状态""" + message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}") + if not message_data: + return None + + return message_data + + async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: + """ + 从 Redis 获取缓存的消息 + + Args: + channel_id: 频道 ID + limit: 限制数量 + since: 获取自此消息 ID 之后的消息 + + Returns: + 消息列表 + """ + # 从 Redis 获取频道最近的消息 UUID 列表 + message_uuids = await self._run_in_executor( + self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1 + ) + + messages = [] + for uuid_str in message_uuids: + message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}") + if message_data: + # 检查是否满足 since 条件 + if since > 0 and "message_id" in message_data: + if int(message_data["message_id"]) <= since: + continue + + messages.append(message_data) + + return messages[::-1] # 返回时间顺序 + + async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100): + """将消息 UUID 缓存到频道消息列表""" + # 添加到频道消息列表开头 + await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid) + # 限制缓存大小 + await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1) + # 设置过期时间(24小时) + await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400) + + async def _process_message_queue(self): + """异步处理消息队列,批量写入数据库""" + while self._processing: + try: + # 批量获取消息 + message_uuids = [] + for _ in range(self._batch_size): + result = await self._run_in_executor( + lambda: self.redis.brpop(["message_queue"], timeout=1) + ) + if result: + message_uuids.append(result[1]) + else: + break + + if message_uuids: + await self._process_message_batch(message_uuids) + else: + # 没有消息时短暂等待 + await asyncio.sleep(0.1) + + except Exception as e: + logger.error(f"Error processing message queue: {e}") + await asyncio.sleep(1) # 错误时等待1秒再重试 + + async def _process_message_batch(self, message_uuids: list[str]): + """批量处理消息写入数据库""" + async with with_db() as session: + messages_to_insert = [] + + for temp_uuid in message_uuids: + try: + # 获取消息数据 + message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}") + if not message_data: + continue + + # 更新状态为处理中 + await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing") + + # 创建数据库消息对象 + msg = ChatMessage( + channel_id=int(message_data["channel_id"]), + content=message_data["content"], + sender_id=int(message_data["sender_id"]), + type=MessageType(message_data["type"]), + uuid=message_data.get("user_uuid") # 用户提供的 UUID(如果有) + ) + + messages_to_insert.append((msg, temp_uuid)) + + except Exception as e: + logger.error(f"Error preparing message {temp_uuid}: {e}") + await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed") + + if messages_to_insert: + try: + # 批量插入数据库 + for msg, temp_uuid in messages_to_insert: + session.add(msg) + + await session.commit() + + # 更新所有消息状态和真实 ID + for msg, temp_uuid in messages_to_insert: + await session.refresh(msg) + await self._run_in_executor( + lambda: self.redis.hset(f"msg:{temp_uuid}", mapping={ + "status": "completed", + "message_id": str(msg.message_id), + "created_at": msg.timestamp.isoformat() if msg.timestamp else "" + }) + ) + + logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}") + + except Exception as e: + logger.error(f"Error inserting messages to database: {e}") + await session.rollback() + + # 标记所有消息为失败 + for _, temp_uuid in messages_to_insert: + await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed") + + +# 全局消息队列实例 +message_queue = MessageQueue() + + +async def start_message_queue(): + """启动消息队列处理""" + await message_queue.start_processing() + + +async def stop_message_queue(): + """停止消息队列处理""" + await message_queue.stop_processing() diff --git a/app/service/message_queue_processor.py b/app/service/message_queue_processor.py new file mode 100644 index 0000000..3f4180f --- /dev/null +++ b/app/service/message_queue_processor.py @@ -0,0 +1,282 @@ +""" +消息队列处理服务 +专门处理 Redis 消息队列的异步写入数据库 +""" + +import asyncio +import json +import uuid +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Optional + +from app.database.chat import ChatMessage, MessageType +from app.dependencies.database import get_redis_message, with_db +from app.log import logger + + +class MessageQueueProcessor: + """消息队列处理器""" + + def __init__(self): + self.redis_message = get_redis_message() + self.executor = ThreadPoolExecutor(max_workers=2) + self._processing = False + self._queue_task = None + + async def _redis_exec(self, func, *args, **kwargs): + """在线程池中执行 Redis 操作""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs)) + + async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str): + """将消息缓存到 Redis""" + try: + # 存储消息数据 + await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data) + await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期 + + # 加入频道消息列表 + await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid) + await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条 + await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期 + + # 加入异步处理队列 + await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid) + + logger.info(f"Message cached to Redis: {temp_uuid}") + except Exception as e: + logger.error(f"Failed to cache message to Redis: {e}") + + async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: + """从 Redis 获取缓存的消息""" + try: + message_uuids = await self._redis_exec( + self.redis_message.lrange, f"channel:{channel_id}:messages", 0, limit - 1 + ) + + messages = [] + for temp_uuid in message_uuids: + # 解码 UUID 如果它是字节类型 + if isinstance(temp_uuid, bytes): + temp_uuid = temp_uuid.decode('utf-8') + + raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") + if raw_data: + # 解码 Redis 返回的字节数据 + message_data = { + k.decode('utf-8') if isinstance(k, bytes) else k: + v.decode('utf-8') if isinstance(v, bytes) else v + for k, v in raw_data.items() + } + + # 检查 since 条件 + if since > 0 and message_data.get("message_id"): + if int(message_data["message_id"]) <= since: + continue + messages.append(message_data) + + return messages[::-1] # 按时间顺序返回 + except Exception as e: + logger.error(f"Failed to get cached messages: {e}") + return [] + + async def update_message_status(self, temp_uuid: str, status: str, message_id: Optional[int] = None): + """更新消息状态""" + try: + update_data = {"status": status} + if message_id: + update_data["message_id"] = str(message_id) + update_data["db_timestamp"] = datetime.now().isoformat() + + await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data) + except Exception as e: + logger.error(f"Failed to update message status: {e}") + + async def get_message_status(self, temp_uuid: str) -> Optional[dict]: + """获取消息状态""" + try: + raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") + if not raw_data: + return None + + # 解码 Redis 返回的字节数据 + return { + k.decode('utf-8') if isinstance(k, bytes) else k: + v.decode('utf-8') if isinstance(v, bytes) else v + for k, v in raw_data.items() + } + except Exception as e: + logger.error(f"Failed to get message status: {e}") + return None + + async def _process_message_queue(self): + """处理消息队列,异步写入数据库""" + logger.info("Message queue processing started") + + while self._processing: + try: + # 批量获取消息 + message_uuids = [] + for _ in range(20): # 批量处理20条消息 + result = await self._redis_exec( + self.redis_message.brpop, ["message_write_queue"], timeout=1 + ) + if result: + # result是 (queue_name, value) 的元组,需要解码 + uuid_value = result[1] + if isinstance(uuid_value, bytes): + uuid_value = uuid_value.decode('utf-8') + message_uuids.append(uuid_value) + else: + break + + if not message_uuids: + await asyncio.sleep(0.5) + continue + + # 批量写入数据库 + await self._process_message_batch(message_uuids) + + except Exception as e: + logger.error(f"Error in message queue processing: {e}") + await asyncio.sleep(1) + + logger.info("Message queue processing stopped") + + async def _process_message_batch(self, message_uuids: list[str]): + """批量处理消息写入数据库""" + async with with_db() as session: + for temp_uuid in message_uuids: + try: + # 获取消息数据并解码 + raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") + if not raw_data: + continue + + # 解码 Redis 返回的字节数据 + message_data = { + k.decode('utf-8') if isinstance(k, bytes) else k: + v.decode('utf-8') if isinstance(v, bytes) else v + for k, v in raw_data.items() + } + + if message_data.get("status") != "pending": + continue + + # 更新状态为处理中 + await self.update_message_status(temp_uuid, "processing") + + # 创建数据库消息 + msg = ChatMessage( + channel_id=int(message_data["channel_id"]), + content=message_data["content"], + sender_id=int(message_data["sender_id"]), + type=MessageType(message_data["type"]), + uuid=message_data.get("user_uuid") or None, + ) + + session.add(msg) + await session.commit() + await session.refresh(msg) + + # 更新成功状态,包含临时消息ID映射 + assert msg.message_id is not None + await self.update_message_status(temp_uuid, "completed", msg.message_id) + + # 如果有临时消息ID,存储映射关系并通知客户端更新 + if message_data.get("temp_message_id"): + temp_msg_id = int(message_data["temp_message_id"]) + await self._redis_exec( + self.redis_message.set, + f"temp_to_real:{temp_msg_id}", + str(msg.message_id), + ex=3600 # 1小时过期 + ) + + # 发送消息ID更新通知到频道 + channel_id = int(message_data["channel_id"]) + await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data) + + logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}") + + except Exception as e: + logger.error(f"Failed to process message {temp_uuid}: {e}") + await self.update_message_status(temp_uuid, "failed") + + async def _notify_message_update(self, channel_id: int, temp_message_id: int, real_message_id: int, message_data: dict): + """通知客户端消息ID已更新""" + try: + # 这里我们需要通过 SignalR 发送消息更新通知 + # 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件 + update_event = { + "event": "chat.message.update", + "data": { + "channel_id": channel_id, + "temp_message_id": temp_message_id, + "real_message_id": real_message_id, + "timestamp": message_data.get("timestamp") + } + } + + # 发布到 Redis 频道,让 SignalR 服务处理 + await self._redis_exec( + self.redis_message.publish, + f"chat_updates:{channel_id}", + json.dumps(update_event) + ) + + logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}") + + except Exception as e: + logger.error(f"Failed to notify message update: {e}") + + def start_processing(self): + """启动消息队列处理""" + if not self._processing: + self._processing = True + self._queue_task = asyncio.create_task(self._process_message_queue()) + logger.info("Message queue processor started") + + def stop_processing(self): + """停止消息队列处理""" + if self._processing: + self._processing = False + if self._queue_task: + self._queue_task.cancel() + self._queue_task = None + logger.info("Message queue processor stopped") + + def __del__(self): + """清理资源""" + if hasattr(self, 'executor'): + self.executor.shutdown(wait=False) + + +# 全局消息队列处理器实例 +message_queue_processor = MessageQueueProcessor() + + +def start_message_processing(): + """启动消息队列处理""" + message_queue_processor.start_processing() + + +def stop_message_processing(): + """停止消息队列处理""" + message_queue_processor.stop_processing() + + +async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid: str): + """将消息缓存到 Redis - 便捷接口""" + await message_queue_processor.cache_message(channel_id, message_data, temp_uuid) + + +async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: + """从 Redis 获取缓存的消息 - 便捷接口""" + return await message_queue_processor.get_cached_messages(channel_id, limit, since) + + +async def get_message_status(temp_uuid: str) -> Optional[dict]: + """获取消息状态 - 便捷接口""" + return await message_queue_processor.get_message_status(temp_uuid) diff --git a/app/service/optimized_message.py b/app/service/optimized_message.py new file mode 100644 index 0000000..6ffd10d --- /dev/null +++ b/app/service/optimized_message.py @@ -0,0 +1,150 @@ +""" +优化的消息服务 +结合 Redis 缓存和异步数据库写入实现实时消息传送 +""" + +from typing import Optional +from fastapi import HTTPException + +from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType, ChatMessageResp +from app.database.lazer_user import User +from app.router.notification.server import server +from app.service.message_queue import message_queue +from app.log import logger +from sqlalchemy.ext.asyncio import AsyncSession + + +class OptimizedMessageService: + """优化的消息服务""" + + def __init__(self): + self.message_queue = message_queue + + async def send_message_fast( + self, + channel_id: int, + channel_type: ChannelType, + channel_name: str, + content: str, + sender: User, + is_action: bool = False, + user_uuid: Optional[str] = None, + session: Optional[AsyncSession] = None + ) -> ChatMessageResp: + """ + 快速发送消息(先缓存到 Redis,异步写入数据库) + + Args: + channel_id: 频道 ID + channel_type: 频道类型 + channel_name: 频道名称 + content: 消息内容 + sender: 发送者 + is_action: 是否为动作消息 + user_uuid: 用户提供的 UUID + session: 数据库会话(可选,用于一些验证) + + Returns: + 消息响应对象 + """ + assert sender.id is not None + + # 准备消息数据 + message_data = { + "channel_id": str(channel_id), + "content": content, + "sender_id": str(sender.id), + "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value, + "user_uuid": user_uuid or "", + "channel_type": channel_type.value, + "channel_name": channel_name + } + + # 立即将消息加入 Redis 队列(实时响应) + temp_uuid = await self.message_queue.enqueue_message(message_data) + + # 缓存到频道消息列表 + await self.message_queue.cache_channel_message(channel_id, temp_uuid) + + # 创建临时响应对象(简化版本,用于立即响应) + from datetime import datetime + from app.database.lazer_user import UserResp + + # 创建基本的用户响应对象 + user_resp = UserResp( + id=sender.id, + username=sender.username, + country_code=getattr(sender, 'country_code', 'XX'), + # 基本字段,其他复杂字段可以后续异步加载 + ) + + temp_response = ChatMessageResp( + message_id=0, # 临时 ID,等数据库写入后会更新 + channel_id=channel_id, + content=content, + timestamp=datetime.now(), + sender_id=sender.id, + sender=user_resp, + is_action=is_action, + uuid=user_uuid + ) + temp_response.temp_uuid = temp_uuid # 添加临时 UUID 用于后续更新 + + logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}") + return temp_response + + async def get_cached_messages( + self, + channel_id: int, + limit: int = 50, + since: int = 0 + ) -> list[dict]: + """ + 获取缓存的消息 + + Args: + channel_id: 频道 ID + limit: 限制数量 + since: 获取自此消息 ID 之后的消息 + + Returns: + 消息列表 + """ + return await self.message_queue.get_cached_messages(channel_id, limit, since) + + async def get_message_status(self, temp_uuid: str) -> Optional[dict]: + """ + 获取消息状态 + + Args: + temp_uuid: 临时消息 UUID + + Returns: + 消息状态信息 + """ + return await self.message_queue.get_message_status(temp_uuid) + + async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> Optional[dict]: + """ + 等待消息持久化到数据库 + + Args: + temp_uuid: 临时消息 UUID + timeout: 超时时间(秒) + + Returns: + 完成后的消息状态 + """ + import asyncio + + for _ in range(timeout * 10): # 每100ms检查一次 + status = await self.get_message_status(temp_uuid) + if status and status.get("status") in ["completed", "failed"]: + return status + await asyncio.sleep(0.1) + + return None + + +# 全局优化消息服务实例 +optimized_message_service = OptimizedMessageService() diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py new file mode 100644 index 0000000..cd58ffb --- /dev/null +++ b/app/service/redis_message_system.py @@ -0,0 +1,537 @@ +""" +基于 Redis 的实时消息系统 +- 消息立即存储到 Redis 并实时返回 +- 定时批量存储到数据库 +- 支持消息状态同步和故障恢复 +""" + +import asyncio +import json +import time +import uuid +from datetime import datetime +from typing import Optional, List, Dict, Any +from concurrent.futures import ThreadPoolExecutor + +from app.database.chat import ChatMessage, MessageType, ChatMessageResp +from app.database.lazer_user import User, UserResp, RANKING_INCLUDES +from app.dependencies.database import get_redis_message, with_db +from app.log import logger + + +class RedisMessageSystem: + """Redis 消息系统""" + + def __init__(self): + self.redis = get_redis_message() + self.executor = ThreadPoolExecutor(max_workers=2) + self._batch_timer: Optional[asyncio.Task] = None + self._running = False + self.batch_interval = 5.0 # 5秒批量存储一次 + self.max_batch_size = 100 # 每批最多处理100条消息 + + async def _redis_exec(self, func, *args, **kwargs): + """在线程池中执行 Redis 操作""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs)) + + async def send_message(self, channel_id: int, user: User, content: str, + is_action: bool = False, user_uuid: Optional[str] = None) -> ChatMessageResp: + """ + 发送消息 - 立即存储到 Redis 并返回 + + Args: + channel_id: 频道ID + user: 发送用户 + content: 消息内容 + is_action: 是否为动作消息 + user_uuid: 用户UUID + + Returns: + ChatMessageResp: 消息响应对象 + """ + # 生成消息ID和时间戳 + message_id = await self._generate_message_id(channel_id) + timestamp = datetime.now() + + # 确保用户ID存在 + if not user.id: + raise ValueError("User ID is required") + + # 准备消息数据 + message_data = { + "message_id": message_id, + "channel_id": channel_id, + "sender_id": user.id, + "content": content, + "timestamp": timestamp.isoformat(), + "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value, + "uuid": user_uuid or "", + "status": "cached", # Redis 缓存状态 + "created_at": time.time() + } + + # 立即存储到 Redis + await self._store_to_redis(message_id, channel_id, message_data) + + # 创建响应对象 + async with with_db() as session: + user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES) + + # 确保 statistics 不为空 + if user_resp.statistics is None: + from app.database.statistics import UserStatisticsResp + user_resp.statistics = UserStatisticsResp( + mode=user.playmode, + global_rank=0, + country_rank=0, + pp=0.0, + ranked_score=0, + hit_accuracy=0.0, + play_count=0, + play_time=0, + total_score=0, + total_hits=0, + maximum_combo=0, + replays_watched_by_others=0, + is_ranked=False, + grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0}, + level={"current": 1, "progress": 0} + ) + + response = ChatMessageResp( + message_id=message_id, + channel_id=channel_id, + content=content, + timestamp=timestamp, + sender_id=user.id, + sender=user_resp, + is_action=is_action, + uuid=user_uuid + ) + + logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}") + return response + + async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> List[ChatMessageResp]: + """ + 获取频道消息 - 优先从 Redis 获取最新消息 + + Args: + channel_id: 频道ID + limit: 消息数量限制 + since: 起始消息ID + + Returns: + List[ChatMessageResp]: 消息列表 + """ + messages = [] + + try: + # 从 Redis 获取最新消息 + redis_messages = await self._get_from_redis(channel_id, limit, since) + + # 为每条消息构建响应对象 + async with with_db() as session: + for msg_data in redis_messages: + # 获取发送者信息 + sender = await session.get(User, msg_data["sender_id"]) + if sender: + user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES) + + if user_resp.statistics is None: + from app.database.statistics import UserStatisticsResp + user_resp.statistics = UserStatisticsResp( + mode=sender.playmode, + global_rank=0, country_rank=0, pp=0.0, + ranked_score=0, hit_accuracy=0.0, play_count=0, + play_time=0, total_score=0, total_hits=0, + maximum_combo=0, replays_watched_by_others=0, + is_ranked=False, + grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0}, + level={"current": 1, "progress": 0} + ) + + message_resp = ChatMessageResp( + message_id=msg_data["message_id"], + channel_id=msg_data["channel_id"], + content=msg_data["content"], + timestamp=datetime.fromisoformat(msg_data["timestamp"]), + sender_id=msg_data["sender_id"], + sender=user_resp, + is_action=msg_data["type"] == MessageType.ACTION.value, + uuid=msg_data.get("uuid") or None + ) + messages.append(message_resp) + + # 如果 Redis 消息不够,从数据库补充 + if len(messages) < limit and since == 0: + await self._backfill_from_database(channel_id, messages, limit) + + except Exception as e: + logger.error(f"Failed to get messages from Redis: {e}") + # 回退到数据库查询 + messages = await self._get_from_database_only(channel_id, limit, since) + + return messages[:limit] + + async def _generate_message_id(self, channel_id: int) -> int: + """生成唯一的消息ID - 确保全局唯一且严格递增""" + # 使用全局计数器确保所有频道的消息ID都是严格递增的 + message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter") + + # 同时更新频道的最后消息ID,用于客户端状态同步 + await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id) + + return message_id + + async def _store_to_redis(self, message_id: int, channel_id: int, message_data: Dict[str, Any]): + """存储消息到 Redis""" + try: + # 存储消息数据 + await self._redis_exec( + self.redis.hset, + f"msg:{channel_id}:{message_id}", + mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) + for k, v in message_data.items()} + ) + + # 设置消息过期时间(7天) + await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800) + + # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序) + channel_messages_key = f"channel:{channel_id}:messages" + + # 检查键的类型,如果不是 zset 类型则删除 + try: + key_type = await self._redis_exec(self.redis.type, channel_messages_key) + if key_type and key_type != "zset": + logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") + await self._redis_exec(self.redis.delete, channel_messages_key) + except Exception as type_check_error: + logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") + # 如果检查失败,直接删除键以确保清理 + await self._redis_exec(self.redis.delete, channel_messages_key) + + # 添加到频道消息列表(sorted set) + await self._redis_exec( + self.redis.zadd, + channel_messages_key, + {f"msg:{channel_id}:{message_id}": message_id} + ) + + # 保持频道消息列表大小(最多1000条) + await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001) + + # 添加到待持久化队列 + await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}") + + except Exception as e: + logger.error(f"Failed to store message to Redis: {e}") + raise + + async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]: + """从 Redis 获取消息""" + try: + # 获取消息键列表,按消息ID排序 + if since > 0: + # 获取指定ID之后的消息(正序) + message_keys = await self._redis_exec( + self.redis.zrangebyscore, + f"channel:{channel_id}:messages", + since + 1, "+inf", + start=0, num=limit + ) + else: + # 获取最新的消息(倒序获取,然后反转) + message_keys = await self._redis_exec( + self.redis.zrevrange, + f"channel:{channel_id}:messages", + 0, limit - 1 + ) + + messages = [] + for key in message_keys: + if isinstance(key, bytes): + key = key.decode('utf-8') + + # 获取消息数据 + raw_data = await self._redis_exec(self.redis.hgetall, key) + if raw_data: + # 解码数据 + message_data = {} + for k, v in raw_data.items(): + if isinstance(k, bytes): + k = k.decode('utf-8') + if isinstance(v, bytes): + v = v.decode('utf-8') + + # 尝试解析 JSON + try: + if k in ['grade_counts', 'level'] or v.startswith(('{', '[')): + message_data[k] = json.loads(v) + elif k in ['message_id', 'channel_id', 'sender_id']: + message_data[k] = int(v) + elif k == 'created_at': + message_data[k] = float(v) + else: + message_data[k] = v + except (json.JSONDecodeError, ValueError): + message_data[k] = v + + messages.append(message_data) + + # 确保消息按ID正序排序(时间顺序) + messages.sort(key=lambda x: x.get('message_id', 0)) + + # 如果是获取最新消息(since=0),需要保持倒序(最新的在前面) + if since == 0: + messages.reverse() + + return messages + + except Exception as e: + logger.error(f"Failed to get messages from Redis: {e}") + return [] + + async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int): + """从数据库补充历史消息""" + try: + # 找到最小的消息ID + min_id = float('inf') + if existing_messages: + for msg in existing_messages: + if msg.message_id is not None and msg.message_id < min_id: + min_id = msg.message_id + + needed = limit - len(existing_messages) + + if needed <= 0: + return + + async with with_db() as session: + from sqlmodel import select, col + query = select(ChatMessage).where( + ChatMessage.channel_id == channel_id + ) + + if min_id != float('inf'): + query = query.where(col(ChatMessage.message_id) < min_id) + + query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed) + + db_messages = (await session.exec(query)).all() + + for msg in reversed(db_messages): # 按时间正序插入 + msg_resp = await ChatMessageResp.from_db(msg, session) + existing_messages.insert(0, msg_resp) + + except Exception as e: + logger.error(f"Failed to backfill from database: {e}") + + async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]: + """仅从数据库获取消息(回退方案)""" + try: + async with with_db() as session: + from sqlmodel import select, col + query = select(ChatMessage).where(ChatMessage.channel_id == channel_id) + + if since > 0: + # 获取指定ID之后的消息,按ID正序 + query = query.where(col(ChatMessage.message_id) > since) + query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit) + else: + # 获取最新消息,按ID倒序(最新的在前面) + query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) + + messages = (await session.exec(query)).all() + + results = [await ChatMessageResp.from_db(msg, session) for msg in messages] + + # 如果是 since > 0,保持正序;否则反转为时间正序 + if since == 0: + results.reverse() + + return results + + except Exception as e: + logger.error(f"Failed to get messages from database: {e}") + return [] + + async def _batch_persist_to_database(self): + """批量持久化消息到数据库""" + logger.info("Starting batch persistence to database") + + while self._running: + try: + # 获取待处理的消息 + message_keys = [] + for _ in range(self.max_batch_size): + key = await self._redis_exec( + self.redis.brpop, ["pending_messages"], timeout=1 + ) + if key: + # key 是 (queue_name, value) 的元组 + value = key[1] + if isinstance(value, bytes): + value = value.decode('utf-8') + message_keys.append(value) + else: + break + + if message_keys: + await self._process_message_batch(message_keys) + else: + await asyncio.sleep(self.batch_interval) + + except Exception as e: + logger.error(f"Error in batch persistence: {e}") + await asyncio.sleep(1) + + logger.info("Stopped batch persistence to database") + + async def _process_message_batch(self, message_keys: List[str]): + """处理消息批次""" + async with with_db() as session: + for key in message_keys: + try: + # 解析频道ID和消息ID + channel_id, message_id = map(int, key.split(':')) + + # 从 Redis 获取消息数据 + raw_data = await self._redis_exec( + self.redis.hgetall, f"msg:{channel_id}:{message_id}" + ) + + if not raw_data: + continue + + # 解码数据 + message_data = {} + for k, v in raw_data.items(): + if isinstance(k, bytes): + k = k.decode('utf-8') + if isinstance(v, bytes): + v = v.decode('utf-8') + message_data[k] = v + + # 检查消息是否已存在于数据库 + existing = await session.get(ChatMessage, int(message_id)) + if existing: + continue + + # 创建数据库消息 - 使用 Redis 生成的正数ID + db_message = ChatMessage( + message_id=int(message_id), # 使用 Redis 系统生成的正数ID + channel_id=int(message_data["channel_id"]), + sender_id=int(message_data["sender_id"]), + content=message_data["content"], + timestamp=datetime.fromisoformat(message_data["timestamp"]), + type=MessageType(message_data["type"]), + uuid=message_data.get("uuid") or None + ) + + session.add(db_message) + + # 更新 Redis 中的状态 + await self._redis_exec( + self.redis.hset, + f"msg:{channel_id}:{message_id}", + "status", "persisted" + ) + + logger.debug(f"Message {message_id} persisted to database") + + except Exception as e: + logger.error(f"Failed to process message {key}: {e}") + + # 提交批次 + try: + await session.commit() + logger.info(f"Batch of {len(message_keys)} messages committed to database") + except Exception as e: + logger.error(f"Failed to commit message batch: {e}") + await session.rollback() + + def start(self): + """启动系统""" + if not self._running: + self._running = True + self._batch_timer = asyncio.create_task(self._batch_persist_to_database()) + # 启动时初始化消息ID计数器 + asyncio.create_task(self._initialize_message_counter()) + logger.info("Redis message system started") + + async def _initialize_message_counter(self): + """初始化全局消息ID计数器,确保从数据库最大ID开始""" + try: + # 清理可能存在的问题键 + await self._cleanup_redis_keys() + + async with with_db() as session: + from sqlmodel import select, func + + # 获取数据库中最大的消息ID + result = await session.exec( + select(func.max(ChatMessage.message_id)) + ) + max_id = result.one() or 0 + + # 检查 Redis 中的计数器值 + current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter") + current_counter = int(current_counter) if current_counter else 0 + + # 设置计数器为两者中的最大值 + initial_counter = max(max_id, current_counter) + await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter) + + logger.info(f"Initialized global message ID counter to {initial_counter}") + + except Exception as e: + logger.error(f"Failed to initialize message counter: {e}") + # 如果初始化失败,设置一个安全的起始值 + await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000) + + async def _cleanup_redis_keys(self): + """清理可能存在问题的 Redis 键""" + try: + # 扫描所有 channel:*:messages 键并检查类型 + keys_pattern = "channel:*:messages" + keys = await self._redis_exec(self.redis.keys, keys_pattern) + + for key in keys: + if isinstance(key, bytes): + key = key.decode('utf-8') + + try: + key_type = await self._redis_exec(self.redis.type, key) + if key_type and key_type != "zset": + logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}") + await self._redis_exec(self.redis.delete, key) + except Exception as cleanup_error: + logger.warning(f"Failed to cleanup key {key}: {cleanup_error}") + # 强制删除问题键 + await self._redis_exec(self.redis.delete, key) + + logger.info("Redis keys cleanup completed") + + except Exception as e: + logger.error(f"Failed to cleanup Redis keys: {e}") + + def stop(self): + """停止系统""" + if self._running: + self._running = False + if self._batch_timer: + self._batch_timer.cancel() + self._batch_timer = None + logger.info("Redis message system stopped") + + def __del__(self): + """清理资源""" + if hasattr(self, 'executor'): + self.executor.shutdown(wait=False) + + +# 全局消息系统实例 +redis_message_system = RedisMessageSystem() diff --git a/main.py b/main.py index 51d1f6c..2f449ea 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,7 @@ from app.service.init_geoip import init_geoip from app.service.load_achievements import load_achievements from app.service.osu_rx_statistics import create_rx_statistics from app.service.recalculate import recalculate +from app.service.redis_message_system import redis_message_system # 检查 New Relic 配置文件是否存在,如果存在则初始化 New Relic newrelic_config_path = os.path.join(os.path.dirname(__file__), "newrelic.ini") @@ -77,10 +78,12 @@ async def lifespan(app: FastAPI): await create_banchobot() await download_service.start_health_check() # 启动下载服务健康检查 await start_cache_scheduler() # 启动缓存调度器 + redis_message_system.start() # 启动 Redis 消息系统 load_achievements() # on shutdown yield stop_scheduler() + redis_message_system.stop() # 停止 Redis 消息系统 await stop_cache_scheduler() # 停止缓存调度器 await download_service.stop_health_check() # 停止下载服务健康检查 await engine.dispose()