add message redis

This commit is contained in:
咕谷酱
2025-08-22 01:49:03 +08:00
parent 36b695b531
commit 1fe603f416
11 changed files with 1461 additions and 86 deletions

View File

@@ -59,12 +59,17 @@ class ChatChannel(ChatChannelBase, table=True):
cls, channel: str | int, session: AsyncSession
) -> "ChatChannel | None":
if isinstance(channel, int) or channel.isdigit():
channel_ = await session.get(ChatChannel, channel)
# 使用查询而不是 get() 来确保对象完全加载
result = await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
channel_ = result.first()
if channel_ is not None:
return channel_
return (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
result = await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
return result.first()
@classmethod
async def get_pm_channel(
@@ -235,6 +240,7 @@ class UserSilenceResp(SQLModel):
@classmethod
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
assert db_silence.id is not None
return cls(
id=db_silence.id,
user_id=db_silence.user_id,

View File

@@ -11,6 +11,7 @@ from app.config import settings
from fastapi import Depends
from pydantic import BaseModel
import redis.asyncio as redis
import redis as sync_redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -38,6 +39,9 @@ engine = create_async_engine(
# Redis 连接
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
# 数据库依赖
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
@@ -80,5 +84,10 @@ def get_redis():
return redis_client
def get_redis_message():
"""获取消息专用的 Redis 客户端 (db1)"""
return redis_message_client
def get_redis_pubsub():
return redis_client.pubsub()

View File

@@ -151,7 +151,6 @@ class PlaylistItem(BaseModel):
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
for i, mod1 in enumerate(mods):
@@ -168,7 +167,6 @@ class PlaylistItem(BaseModel):
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
API_MODS = self._get_api_mods()
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
@@ -213,8 +211,6 @@ class PlaylistItem(BaseModel):
"""
from typing import Literal, cast
API_MODS = self._get_api_mods()
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
@@ -386,7 +382,7 @@ class MultiplayerRoom(BaseModel):
match_state=None,
playlist=playlist,
active_countdowns=[],
channel_id=room.channel_id,
channel_id=room.channel_id or 0,
)

View File

@@ -53,16 +53,24 @@ async def get_update(
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
channel = await ChatChannel.get(channel_id, session)
if channel:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
resp.presence.append(
await ChatChannelResp.from_db(
channel,
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel.type != ChannelType.PUBLIC
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -105,7 +113,19 @@ async def join_channel(
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -125,7 +145,19 @@ async def leave_channel(
user: str = Path(..., description="用户 ID"),
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -152,15 +184,19 @@ async def get_channel_list(
).all()
results = []
for channel in channels:
assert channel.channel_id is not None
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
assert channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, [])
if channel.type != ChannelType.PUBLIC
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -185,14 +221,33 @@ async def get_channel(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id is not None
# 立即提取需要的属性
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
users = []
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
if len(user_ids) != 2:
raise HTTPException(status_code=404, detail="Target user not found")
for id_ in user_ids:
@@ -210,8 +265,8 @@ async def get_channel(
session,
current_user,
redis,
server.channels.get(db_channel.channel_id, [])
if db_channel.type != ChannelType.PUBLIC
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
)
)
@@ -270,7 +325,8 @@ async def create_channel(
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
channel = await ChatChannel.get(channel_name, session)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
channel = result.first()
if channel is None:
channel = ChatChannel(
@@ -294,12 +350,16 @@ async def create_channel(
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
assert channel.channel_id
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel.channel_id, []),
server.channels.get(channel_id, []),
include_recent_messages=True,
)

View File

@@ -1,5 +1,10 @@
from __future__ import annotations
import json
import uuid
from datetime import datetime
from typing import Optional
from app.database import ChatMessageResp
from app.database.chat import (
ChannelType,
@@ -11,11 +16,14 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.lazer_user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.database import Database, get_redis, get_redis_message
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from app.service.optimized_message import optimized_message_service
from app.service.redis_message_system import redis_message_system
from app.log import logger
from .banchobot import bot
from .server import server
@@ -89,42 +97,73 @@ async def send_message(
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
# 立即提取所有需要的属性,避免后续延迟加载
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
assert current_user.id
msg = ChatMessage(
channel_id=db_channel.channel_id,
# 使用 Redis 消息系统发送消息 - 立即返回
resp = await redis_message_system.send_message(
channel_id=channel_id,
user=current_user,
content=req.message,
sender_id=current_user.id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
is_action=req.is_action,
user_uuid=req.uuid
)
session.add(msg)
await session.commit()
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(db_channel)
resp = await ChatMessageResp.from_db(msg, session, current_user)
# 立即广播消息给所有客户端
is_bot_command = req.message.startswith("!")
await server.send_message_to_channel(
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
resp, is_bot_command and channel_type == ChannelType.PUBLIC
)
# 处理机器人命令
if is_bot_command:
await bot.try_handle(current_user, db_channel, req.message, session)
if db_channel.type == ChannelType.PM:
user_ids = db_channel.name.split("_")[1:]
await server.new_private_notification(
ChannelMessage.init(
msg, current_user, [int(u) for u in user_ids], db_channel.type
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
temp_msg = ChatMessage(
message_id=resp.message_id, # 使用 Redis 系统生成的ID
channel_id=channel_id,
content=req.message,
sender_id=current_user.id,
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
uuid=req.uuid,
)
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
await server.new_private_notification(
ChannelMessage.init(
temp_msg, current_user, [int(u) for u in user_ids], channel_type
)
)
)
elif db_channel.type == ChannelType.TEAM:
await server.new_private_notification(
ChannelMessageTeam.init(msg, current_user)
)
elif channel_type == ChannelType.TEAM:
await server.new_private_notification(
ChannelMessageTeam.init(temp_msg, current_user)
)
return resp
@@ -143,21 +182,46 @@ async def get_message(
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
messages = await session.exec(
select(ChatMessage)
.where(
ChatMessage.channel_id == db_channel.channel_id,
col(ChatMessage.message_id) > since,
col(ChatMessage.message_id) < until if until is not None else True,
)
.order_by(col(ChatMessage.timestamp).desc())
.limit(limit)
)
# 提取必要的属性避免惰性加载
channel_id = db_channel.channel_id
assert channel_id is not None
# 使用 Redis 消息系统获取消息
try:
messages = await redis_message_system.get_messages(channel_id, limit, since)
return messages
except Exception as e:
logger.warning(f"Failed to get messages from Redis system: {e}")
# 回退到传统数据库查询
pass
# 回退到数据库查询
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
if since > 0:
query = query.where(col(ChatMessage.message_id) > since)
if until is not None:
query = query.where(col(ChatMessage.message_id) < until)
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
messages = (await session.exec(query)).all()
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
resp.reverse()
return resp
@@ -174,12 +238,28 @@ async def mark_as_read(
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
db_channel = await ChatChannel.get(channel, session)
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else:
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
assert db_channel.channel_id
# 立即提取需要的属性
channel_id = db_channel.channel_id
assert channel_id
assert current_user.id
await server.mark_as_read(db_channel.channel_id, current_user.id, message)
await server.mark_as_read(channel_id, current_user.id, message)
class PMReq(BaseModel):

View File

@@ -59,9 +59,14 @@ class ChatServer:
for channel_id, channel in self.channels.items():
if user_id in channel:
channel.remove(user_id)
channel = await ChatChannel.get(channel_id, session)
if channel:
await self.leave_channel(user, channel, session)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@@ -79,8 +84,11 @@ class ChatServer:
await client.send_text(event.model_dump_json())
async def broadcast(self, channel_id: int, event: ChatEvent):
for user_id in self.channels.get(channel_id, []):
users_in_channel = self.channels.get(channel_id, [])
logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")
for user_id in users_in_channel:
await self.send_event(user_id, event)
logger.debug(f"Sent event to user {user_id} in channel {channel_id}")
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
@@ -88,24 +96,35 @@ class ChatServer:
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
event = ChatEvent(
event="chat.message.new",
data={"messages": [message], "users": [message.sender]},
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
self._add_task(self.send_event(message.sender_id, event))
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
self._add_task(
self.broadcast(
message.channel_id,
event,
)
)
assert message.message_id
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
else:
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
@@ -206,27 +225,37 @@ class ChatServer:
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.join_channel(user, channel, session)
await self.join_channel(user, db_channel, session)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
channel = await ChatChannel.get(channel_id, session)
if channel is None:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel is None:
return
user = await session.get(User, user_id)
if user is None:
return
await self.leave_channel(user, channel, session)
await self.leave_channel(user, db_channel, session)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
@@ -309,7 +338,13 @@ async def chat_websocket(
user_id = user.id
assert user_id
server.connect(user_id, websocket)
channel = await ChatChannel.get(1, session)
if channel is not None:
await server.join_channel(user, channel, session)
await _listen_stop(websocket, user_id, factory)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == 1)
)
).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)
await _listen_stop(websocket, user_id, factory)

View File

@@ -0,0 +1,217 @@
"""
Redis 消息队列服务
用于实现实时消息推送和异步数据库持久化
"""
import asyncio
import json
import uuid
from datetime import datetime
from functools import partial
from typing import Optional, Union
import concurrent.futures
from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType
from app.dependencies.database import get_redis, with_db
from app.log import logger
class MessageQueue:
"""Redis 消息队列服务"""
def __init__(self):
self.redis = get_redis()
self._processing = False
self._batch_size = 50 # 批量处理大小
self._batch_timeout = 1.0 # 批量处理超时时间(秒)
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
async def _run_in_executor(self, func, *args):
"""在线程池中运行同步 Redis 操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self._executor, func, *args)
async def start_processing(self):
"""启动消息处理任务"""
if not self._processing:
self._processing = True
asyncio.create_task(self._process_message_queue())
logger.info("Message queue processing started")
async def stop_processing(self):
"""停止消息处理"""
self._processing = False
logger.info("Message queue processing stopped")
async def enqueue_message(self, message_data: dict) -> str:
"""
将消息加入 Redis 队列(实时响应)
Args:
message_data: 消息数据字典,包含所有必要的字段
Returns:
消息的临时 UUID
"""
# 生成临时 UUID
temp_uuid = str(uuid.uuid4())
message_data["temp_uuid"] = temp_uuid
message_data["timestamp"] = datetime.now().isoformat()
message_data["status"] = "pending" # pending, processing, completed, failed
# 将消息存储到 Redis
await self._run_in_executor(
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
)
await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
# 加入处理队列
await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
return temp_uuid
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
"""获取消息状态"""
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
if not message_data:
return None
return message_data
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""
从 Redis 获取缓存的消息
Args:
channel_id: 频道 ID
limit: 限制数量
since: 获取自此消息 ID 之后的消息
Returns:
消息列表
"""
# 从 Redis 获取频道最近的消息 UUID 列表
message_uuids = await self._run_in_executor(
self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
)
messages = []
for uuid_str in message_uuids:
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
if message_data:
# 检查是否满足 since 条件
if since > 0 and "message_id" in message_data:
if int(message_data["message_id"]) <= since:
continue
messages.append(message_data)
return messages[::-1] # 返回时间顺序
async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
"""将消息 UUID 缓存到频道消息列表"""
# 添加到频道消息列表开头
await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
# 限制缓存大小
await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
# 设置过期时间24小时
await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
async def _process_message_queue(self):
"""异步处理消息队列,批量写入数据库"""
while self._processing:
try:
# 批量获取消息
message_uuids = []
for _ in range(self._batch_size):
result = await self._run_in_executor(
lambda: self.redis.brpop(["message_queue"], timeout=1)
)
if result:
message_uuids.append(result[1])
else:
break
if message_uuids:
await self._process_message_batch(message_uuids)
else:
# 没有消息时短暂等待
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error processing message queue: {e}")
await asyncio.sleep(1) # 错误时等待1秒再重试
async def _process_message_batch(self, message_uuids: list[str]):
"""批量处理消息写入数据库"""
async with with_db() as session:
messages_to_insert = []
for temp_uuid in message_uuids:
try:
# 获取消息数据
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
if not message_data:
continue
# 更新状态为处理中
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
# 创建数据库消息对象
msg = ChatMessage(
channel_id=int(message_data["channel_id"]),
content=message_data["content"],
sender_id=int(message_data["sender_id"]),
type=MessageType(message_data["type"]),
uuid=message_data.get("user_uuid") # 用户提供的 UUID如果有
)
messages_to_insert.append((msg, temp_uuid))
except Exception as e:
logger.error(f"Error preparing message {temp_uuid}: {e}")
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
if messages_to_insert:
try:
# 批量插入数据库
for msg, temp_uuid in messages_to_insert:
session.add(msg)
await session.commit()
# 更新所有消息状态和真实 ID
for msg, temp_uuid in messages_to_insert:
await session.refresh(msg)
await self._run_in_executor(
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping={
"status": "completed",
"message_id": str(msg.message_id),
"created_at": msg.timestamp.isoformat() if msg.timestamp else ""
})
)
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
except Exception as e:
logger.error(f"Error inserting messages to database: {e}")
await session.rollback()
# 标记所有消息为失败
for _, temp_uuid in messages_to_insert:
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
# 全局消息队列实例
message_queue = MessageQueue()
async def start_message_queue():
"""启动消息队列处理"""
await message_queue.start_processing()
async def stop_message_queue():
"""停止消息队列处理"""
await message_queue.stop_processing()

View File

@@ -0,0 +1,282 @@
"""
消息队列处理服务
专门处理 Redis 消息队列的异步写入数据库
"""
import asyncio
import json
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional
from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
class MessageQueueProcessor:
"""消息队列处理器"""
def __init__(self):
self.redis_message = get_redis_message()
self.executor = ThreadPoolExecutor(max_workers=2)
self._processing = False
self._queue_task = None
async def _redis_exec(self, func, *args, **kwargs):
"""在线程池中执行 Redis 操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str):
"""将消息缓存到 Redis"""
try:
# 存储消息数据
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data)
await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
# 加入频道消息列表
await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid)
await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条
await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期
# 加入异步处理队列
await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid)
logger.info(f"Message cached to Redis: {temp_uuid}")
except Exception as e:
logger.error(f"Failed to cache message to Redis: {e}")
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""从 Redis 获取缓存的消息"""
try:
message_uuids = await self._redis_exec(
self.redis_message.lrange, f"channel:{channel_id}:messages", 0, limit - 1
)
messages = []
for temp_uuid in message_uuids:
# 解码 UUID 如果它是字节类型
if isinstance(temp_uuid, bytes):
temp_uuid = temp_uuid.decode('utf-8')
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
if raw_data:
# 解码 Redis 返回的字节数据
message_data = {
k.decode('utf-8') if isinstance(k, bytes) else k:
v.decode('utf-8') if isinstance(v, bytes) else v
for k, v in raw_data.items()
}
# 检查 since 条件
if since > 0 and message_data.get("message_id"):
if int(message_data["message_id"]) <= since:
continue
messages.append(message_data)
return messages[::-1] # 按时间顺序返回
except Exception as e:
logger.error(f"Failed to get cached messages: {e}")
return []
async def update_message_status(self, temp_uuid: str, status: str, message_id: Optional[int] = None):
"""更新消息状态"""
try:
update_data = {"status": status}
if message_id:
update_data["message_id"] = str(message_id)
update_data["db_timestamp"] = datetime.now().isoformat()
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data)
except Exception as e:
logger.error(f"Failed to update message status: {e}")
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
"""获取消息状态"""
try:
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
if not raw_data:
return None
# 解码 Redis 返回的字节数据
return {
k.decode('utf-8') if isinstance(k, bytes) else k:
v.decode('utf-8') if isinstance(v, bytes) else v
for k, v in raw_data.items()
}
except Exception as e:
logger.error(f"Failed to get message status: {e}")
return None
async def _process_message_queue(self):
"""处理消息队列,异步写入数据库"""
logger.info("Message queue processing started")
while self._processing:
try:
# 批量获取消息
message_uuids = []
for _ in range(20): # 批量处理20条消息
result = await self._redis_exec(
self.redis_message.brpop, ["message_write_queue"], timeout=1
)
if result:
# result是 (queue_name, value) 的元组,需要解码
uuid_value = result[1]
if isinstance(uuid_value, bytes):
uuid_value = uuid_value.decode('utf-8')
message_uuids.append(uuid_value)
else:
break
if not message_uuids:
await asyncio.sleep(0.5)
continue
# 批量写入数据库
await self._process_message_batch(message_uuids)
except Exception as e:
logger.error(f"Error in message queue processing: {e}")
await asyncio.sleep(1)
logger.info("Message queue processing stopped")
async def _process_message_batch(self, message_uuids: list[str]):
"""批量处理消息写入数据库"""
async with with_db() as session:
for temp_uuid in message_uuids:
try:
# 获取消息数据并解码
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
if not raw_data:
continue
# 解码 Redis 返回的字节数据
message_data = {
k.decode('utf-8') if isinstance(k, bytes) else k:
v.decode('utf-8') if isinstance(v, bytes) else v
for k, v in raw_data.items()
}
if message_data.get("status") != "pending":
continue
# 更新状态为处理中
await self.update_message_status(temp_uuid, "processing")
# 创建数据库消息
msg = ChatMessage(
channel_id=int(message_data["channel_id"]),
content=message_data["content"],
sender_id=int(message_data["sender_id"]),
type=MessageType(message_data["type"]),
uuid=message_data.get("user_uuid") or None,
)
session.add(msg)
await session.commit()
await session.refresh(msg)
# 更新成功状态包含临时消息ID映射
assert msg.message_id is not None
await self.update_message_status(temp_uuid, "completed", msg.message_id)
# 如果有临时消息ID存储映射关系并通知客户端更新
if message_data.get("temp_message_id"):
temp_msg_id = int(message_data["temp_message_id"])
await self._redis_exec(
self.redis_message.set,
f"temp_to_real:{temp_msg_id}",
str(msg.message_id),
ex=3600 # 1小时过期
)
# 发送消息ID更新通知到频道
channel_id = int(message_data["channel_id"])
await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data)
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}")
except Exception as e:
logger.error(f"Failed to process message {temp_uuid}: {e}")
await self.update_message_status(temp_uuid, "failed")
async def _notify_message_update(self, channel_id: int, temp_message_id: int, real_message_id: int, message_data: dict):
"""通知客户端消息ID已更新"""
try:
# 这里我们需要通过 SignalR 发送消息更新通知
# 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件
update_event = {
"event": "chat.message.update",
"data": {
"channel_id": channel_id,
"temp_message_id": temp_message_id,
"real_message_id": real_message_id,
"timestamp": message_data.get("timestamp")
}
}
# 发布到 Redis 频道,让 SignalR 服务处理
await self._redis_exec(
self.redis_message.publish,
f"chat_updates:{channel_id}",
json.dumps(update_event)
)
logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}")
except Exception as e:
logger.error(f"Failed to notify message update: {e}")
def start_processing(self):
"""启动消息队列处理"""
if not self._processing:
self._processing = True
self._queue_task = asyncio.create_task(self._process_message_queue())
logger.info("Message queue processor started")
def stop_processing(self):
"""停止消息队列处理"""
if self._processing:
self._processing = False
if self._queue_task:
self._queue_task.cancel()
self._queue_task = None
logger.info("Message queue processor stopped")
def __del__(self):
"""清理资源"""
if hasattr(self, 'executor'):
self.executor.shutdown(wait=False)
# 全局消息队列处理器实例
message_queue_processor = MessageQueueProcessor()
def start_message_processing():
"""启动消息队列处理"""
message_queue_processor.start_processing()
def stop_message_processing():
"""停止消息队列处理"""
message_queue_processor.stop_processing()
async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid: str):
"""将消息缓存到 Redis - 便捷接口"""
await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)
async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""从 Redis 获取缓存的消息 - 便捷接口"""
return await message_queue_processor.get_cached_messages(channel_id, limit, since)
async def get_message_status(temp_uuid: str) -> Optional[dict]:
"""获取消息状态 - 便捷接口"""
return await message_queue_processor.get_message_status(temp_uuid)

View File

@@ -0,0 +1,150 @@
"""
优化的消息服务
结合 Redis 缓存和异步数据库写入实现实时消息传送
"""
from typing import Optional
from fastapi import HTTPException
from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType, ChatMessageResp
from app.database.lazer_user import User
from app.router.notification.server import server
from app.service.message_queue import message_queue
from app.log import logger
from sqlalchemy.ext.asyncio import AsyncSession
class OptimizedMessageService:
"""优化的消息服务"""
def __init__(self):
self.message_queue = message_queue
async def send_message_fast(
self,
channel_id: int,
channel_type: ChannelType,
channel_name: str,
content: str,
sender: User,
is_action: bool = False,
user_uuid: Optional[str] = None,
session: Optional[AsyncSession] = None
) -> ChatMessageResp:
"""
快速发送消息(先缓存到 Redis异步写入数据库
Args:
channel_id: 频道 ID
channel_type: 频道类型
channel_name: 频道名称
content: 消息内容
sender: 发送者
is_action: 是否为动作消息
user_uuid: 用户提供的 UUID
session: 数据库会话(可选,用于一些验证)
Returns:
消息响应对象
"""
assert sender.id is not None
# 准备消息数据
message_data = {
"channel_id": str(channel_id),
"content": content,
"sender_id": str(sender.id),
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
"user_uuid": user_uuid or "",
"channel_type": channel_type.value,
"channel_name": channel_name
}
# 立即将消息加入 Redis 队列(实时响应)
temp_uuid = await self.message_queue.enqueue_message(message_data)
# 缓存到频道消息列表
await self.message_queue.cache_channel_message(channel_id, temp_uuid)
# 创建临时响应对象(简化版本,用于立即响应)
from datetime import datetime
from app.database.lazer_user import UserResp
# 创建基本的用户响应对象
user_resp = UserResp(
id=sender.id,
username=sender.username,
country_code=getattr(sender, 'country_code', 'XX'),
# 基本字段,其他复杂字段可以后续异步加载
)
temp_response = ChatMessageResp(
message_id=0, # 临时 ID等数据库写入后会更新
channel_id=channel_id,
content=content,
timestamp=datetime.now(),
sender_id=sender.id,
sender=user_resp,
is_action=is_action,
uuid=user_uuid
)
temp_response.temp_uuid = temp_uuid # 添加临时 UUID 用于后续更新
logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
return temp_response
async def get_cached_messages(
self,
channel_id: int,
limit: int = 50,
since: int = 0
) -> list[dict]:
"""
获取缓存的消息
Args:
channel_id: 频道 ID
limit: 限制数量
since: 获取自此消息 ID 之后的消息
Returns:
消息列表
"""
return await self.message_queue.get_cached_messages(channel_id, limit, since)
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
"""
获取消息状态
Args:
temp_uuid: 临时消息 UUID
Returns:
消息状态信息
"""
return await self.message_queue.get_message_status(temp_uuid)
async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> Optional[dict]:
"""
等待消息持久化到数据库
Args:
temp_uuid: 临时消息 UUID
timeout: 超时时间(秒)
Returns:
完成后的消息状态
"""
import asyncio
for _ in range(timeout * 10): # 每100ms检查一次
status = await self.get_message_status(temp_uuid)
if status and status.get("status") in ["completed", "failed"]:
return status
await asyncio.sleep(0.1)
return None
# 全局优化消息服务实例
optimized_message_service = OptimizedMessageService()

View File

@@ -0,0 +1,537 @@
"""
基于 Redis 的实时消息系统
- 消息立即存储到 Redis 并实时返回
- 定时批量存储到数据库
- 支持消息状态同步和故障恢复
"""
import asyncio
import json
import time
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from concurrent.futures import ThreadPoolExecutor
from app.database.chat import ChatMessage, MessageType, ChatMessageResp
from app.database.lazer_user import User, UserResp, RANKING_INCLUDES
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
class RedisMessageSystem:
"""Redis 消息系统"""
def __init__(self):
self.redis = get_redis_message()
self.executor = ThreadPoolExecutor(max_workers=2)
self._batch_timer: Optional[asyncio.Task] = None
self._running = False
self.batch_interval = 5.0 # 5秒批量存储一次
self.max_batch_size = 100 # 每批最多处理100条消息
async def _redis_exec(self, func, *args, **kwargs):
"""在线程池中执行 Redis 操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
async def send_message(self, channel_id: int, user: User, content: str,
is_action: bool = False, user_uuid: Optional[str] = None) -> ChatMessageResp:
"""
发送消息 - 立即存储到 Redis 并返回
Args:
channel_id: 频道ID
user: 发送用户
content: 消息内容
is_action: 是否为动作消息
user_uuid: 用户UUID
Returns:
ChatMessageResp: 消息响应对象
"""
# 生成消息ID和时间戳
message_id = await self._generate_message_id(channel_id)
timestamp = datetime.now()
# 确保用户ID存在
if not user.id:
raise ValueError("User ID is required")
# 准备消息数据
message_data = {
"message_id": message_id,
"channel_id": channel_id,
"sender_id": user.id,
"content": content,
"timestamp": timestamp.isoformat(),
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
"uuid": user_uuid or "",
"status": "cached", # Redis 缓存状态
"created_at": time.time()
}
# 立即存储到 Redis
await self._store_to_redis(message_id, channel_id, message_data)
# 创建响应对象
async with with_db() as session:
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
# 确保 statistics 不为空
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
user_resp.statistics = UserStatisticsResp(
mode=user.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
level={"current": 1, "progress": 0}
)
response = ChatMessageResp(
message_id=message_id,
channel_id=channel_id,
content=content,
timestamp=timestamp,
sender_id=user.id,
sender=user_resp,
is_action=is_action,
uuid=user_uuid
)
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return response
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> List[ChatMessageResp]:
"""
获取频道消息 - 优先从 Redis 获取最新消息
Args:
channel_id: 频道ID
limit: 消息数量限制
since: 起始消息ID
Returns:
List[ChatMessageResp]: 消息列表
"""
messages = []
try:
# 从 Redis 获取最新消息
redis_messages = await self._get_from_redis(channel_id, limit, since)
# 为每条消息构建响应对象
async with with_db() as session:
for msg_data in redis_messages:
# 获取发送者信息
sender = await session.get(User, msg_data["sender_id"])
if sender:
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
user_resp.statistics = UserStatisticsResp(
mode=sender.playmode,
global_rank=0, country_rank=0, pp=0.0,
ranked_score=0, hit_accuracy=0.0, play_count=0,
play_time=0, total_score=0, total_hits=0,
maximum_combo=0, replays_watched_by_others=0,
is_ranked=False,
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
level={"current": 1, "progress": 0}
)
message_resp = ChatMessageResp(
message_id=msg_data["message_id"],
channel_id=msg_data["channel_id"],
content=msg_data["content"],
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
sender_id=msg_data["sender_id"],
sender=user_resp,
is_action=msg_data["type"] == MessageType.ACTION.value,
uuid=msg_data.get("uuid") or None
)
messages.append(message_resp)
# 如果 Redis 消息不够,从数据库补充
if len(messages) < limit and since == 0:
await self._backfill_from_database(channel_id, messages, limit)
except Exception as e:
logger.error(f"Failed to get messages from Redis: {e}")
# 回退到数据库查询
messages = await self._get_from_database_only(channel_id, limit, since)
return messages[:limit]
async def _generate_message_id(self, channel_id: int) -> int:
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
# 使用全局计数器确保所有频道的消息ID都是严格递增的
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
# 同时更新频道的最后消息ID用于客户端状态同步
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
return message_id
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: Dict[str, Any]):
"""存储消息到 Redis"""
try:
# 存储消息数据
await self._redis_exec(
self.redis.hset,
f"msg:{channel_id}:{message_id}",
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
for k, v in message_data.items()}
)
# 设置消息过期时间7天
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
channel_messages_key = f"channel:{channel_id}:messages"
# 检查键的类型,如果不是 zset 类型则删除
try:
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
if key_type and key_type != "zset":
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self._redis_exec(self.redis.delete, channel_messages_key)
except Exception as type_check_error:
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
# 如果检查失败,直接删除键以确保清理
await self._redis_exec(self.redis.delete, channel_messages_key)
# 添加到频道消息列表sorted set
await self._redis_exec(
self.redis.zadd,
channel_messages_key,
{f"msg:{channel_id}:{message_id}": message_id}
)
# 保持频道消息列表大小最多1000条
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
# 添加到待持久化队列
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
raise
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]:
"""从 Redis 获取消息"""
try:
# 获取消息键列表按消息ID排序
if since > 0:
# 获取指定ID之后的消息正序
message_keys = await self._redis_exec(
self.redis.zrangebyscore,
f"channel:{channel_id}:messages",
since + 1, "+inf",
start=0, num=limit
)
else:
# 获取最新的消息(倒序获取,然后反转)
message_keys = await self._redis_exec(
self.redis.zrevrange,
f"channel:{channel_id}:messages",
0, limit - 1
)
messages = []
for key in message_keys:
if isinstance(key, bytes):
key = key.decode('utf-8')
# 获取消息数据
raw_data = await self._redis_exec(self.redis.hgetall, key)
if raw_data:
# 解码数据
message_data = {}
for k, v in raw_data.items():
if isinstance(k, bytes):
k = k.decode('utf-8')
if isinstance(v, bytes):
v = v.decode('utf-8')
# 尝试解析 JSON
try:
if k in ['grade_counts', 'level'] or v.startswith(('{', '[')):
message_data[k] = json.loads(v)
elif k in ['message_id', 'channel_id', 'sender_id']:
message_data[k] = int(v)
elif k == 'created_at':
message_data[k] = float(v)
else:
message_data[k] = v
except (json.JSONDecodeError, ValueError):
message_data[k] = v
messages.append(message_data)
# 确保消息按ID正序排序时间顺序
messages.sort(key=lambda x: x.get('message_id', 0))
# 如果是获取最新消息since=0需要保持倒序最新的在前面
if since == 0:
messages.reverse()
return messages
except Exception as e:
logger.error(f"Failed to get messages from Redis: {e}")
return []
async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int):
"""从数据库补充历史消息"""
try:
# 找到最小的消息ID
min_id = float('inf')
if existing_messages:
for msg in existing_messages:
if msg.message_id is not None and msg.message_id < min_id:
min_id = msg.message_id
needed = limit - len(existing_messages)
if needed <= 0:
return
async with with_db() as session:
from sqlmodel import select, col
query = select(ChatMessage).where(
ChatMessage.channel_id == channel_id
)
if min_id != float('inf'):
query = query.where(col(ChatMessage.message_id) < min_id)
query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed)
db_messages = (await session.exec(query)).all()
for msg in reversed(db_messages): # 按时间正序插入
msg_resp = await ChatMessageResp.from_db(msg, session)
existing_messages.insert(0, msg_resp)
except Exception as e:
logger.error(f"Failed to backfill from database: {e}")
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]:
"""仅从数据库获取消息(回退方案)"""
try:
async with with_db() as session:
from sqlmodel import select, col
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
if since > 0:
# 获取指定ID之后的消息按ID正序
query = query.where(col(ChatMessage.message_id) > since)
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
else:
# 获取最新消息按ID倒序最新的在前面
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
messages = (await session.exec(query)).all()
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
# 如果是 since > 0保持正序否则反转为时间正序
if since == 0:
results.reverse()
return results
except Exception as e:
logger.error(f"Failed to get messages from database: {e}")
return []
async def _batch_persist_to_database(self):
"""批量持久化消息到数据库"""
logger.info("Starting batch persistence to database")
while self._running:
try:
# 获取待处理的消息
message_keys = []
for _ in range(self.max_batch_size):
key = await self._redis_exec(
self.redis.brpop, ["pending_messages"], timeout=1
)
if key:
# key 是 (queue_name, value) 的元组
value = key[1]
if isinstance(value, bytes):
value = value.decode('utf-8')
message_keys.append(value)
else:
break
if message_keys:
await self._process_message_batch(message_keys)
else:
await asyncio.sleep(self.batch_interval)
except Exception as e:
logger.error(f"Error in batch persistence: {e}")
await asyncio.sleep(1)
logger.info("Stopped batch persistence to database")
async def _process_message_batch(self, message_keys: List[str]):
"""处理消息批次"""
async with with_db() as session:
for key in message_keys:
try:
# 解析频道ID和消息ID
channel_id, message_id = map(int, key.split(':'))
# 从 Redis 获取消息数据
raw_data = await self._redis_exec(
self.redis.hgetall, f"msg:{channel_id}:{message_id}"
)
if not raw_data:
continue
# 解码数据
message_data = {}
for k, v in raw_data.items():
if isinstance(k, bytes):
k = k.decode('utf-8')
if isinstance(v, bytes):
v = v.decode('utf-8')
message_data[k] = v
# 检查消息是否已存在于数据库
existing = await session.get(ChatMessage, int(message_id))
if existing:
continue
# 创建数据库消息 - 使用 Redis 生成的正数ID
db_message = ChatMessage(
message_id=int(message_id), # 使用 Redis 系统生成的正数ID
channel_id=int(message_data["channel_id"]),
sender_id=int(message_data["sender_id"]),
content=message_data["content"],
timestamp=datetime.fromisoformat(message_data["timestamp"]),
type=MessageType(message_data["type"]),
uuid=message_data.get("uuid") or None
)
session.add(db_message)
# 更新 Redis 中的状态
await self._redis_exec(
self.redis.hset,
f"msg:{channel_id}:{message_id}",
"status", "persisted"
)
logger.debug(f"Message {message_id} persisted to database")
except Exception as e:
logger.error(f"Failed to process message {key}: {e}")
# 提交批次
try:
await session.commit()
logger.info(f"Batch of {len(message_keys)} messages committed to database")
except Exception as e:
logger.error(f"Failed to commit message batch: {e}")
await session.rollback()
def start(self):
"""启动系统"""
if not self._running:
self._running = True
self._batch_timer = asyncio.create_task(self._batch_persist_to_database())
# 启动时初始化消息ID计数器
asyncio.create_task(self._initialize_message_counter())
logger.info("Redis message system started")
async def _initialize_message_counter(self):
"""初始化全局消息ID计数器确保从数据库最大ID开始"""
try:
# 清理可能存在的问题键
await self._cleanup_redis_keys()
async with with_db() as session:
from sqlmodel import select, func
# 获取数据库中最大的消息ID
result = await session.exec(
select(func.max(ChatMessage.message_id))
)
max_id = result.one() or 0
# 检查 Redis 中的计数器值
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
current_counter = int(current_counter) if current_counter else 0
# 设置计数器为两者中的最大值
initial_counter = max(max_id, current_counter)
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
logger.info(f"Initialized global message ID counter to {initial_counter}")
except Exception as e:
logger.error(f"Failed to initialize message counter: {e}")
# 如果初始化失败,设置一个安全的起始值
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
async def _cleanup_redis_keys(self):
"""清理可能存在问题的 Redis 键"""
try:
# 扫描所有 channel:*:messages 键并检查类型
keys_pattern = "channel:*:messages"
keys = await self._redis_exec(self.redis.keys, keys_pattern)
for key in keys:
if isinstance(key, bytes):
key = key.decode('utf-8')
try:
key_type = await self._redis_exec(self.redis.type, key)
if key_type and key_type != "zset":
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
await self._redis_exec(self.redis.delete, key)
except Exception as cleanup_error:
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
# 强制删除问题键
await self._redis_exec(self.redis.delete, key)
logger.info("Redis keys cleanup completed")
except Exception as e:
logger.error(f"Failed to cleanup Redis keys: {e}")
def stop(self):
"""停止系统"""
if self._running:
self._running = False
if self._batch_timer:
self._batch_timer.cancel()
self._batch_timer = None
logger.info("Redis message system stopped")
def __del__(self):
"""清理资源"""
if hasattr(self, 'executor'):
self.executor.shutdown(wait=False)
# 全局消息系统实例
redis_message_system = RedisMessageSystem()

View File

@@ -32,6 +32,7 @@ from app.service.init_geoip import init_geoip
from app.service.load_achievements import load_achievements
from app.service.osu_rx_statistics import create_rx_statistics
from app.service.recalculate import recalculate
from app.service.redis_message_system import redis_message_system
# 检查 New Relic 配置文件是否存在,如果存在则初始化 New Relic
newrelic_config_path = os.path.join(os.path.dirname(__file__), "newrelic.ini")
@@ -77,10 +78,12 @@ async def lifespan(app: FastAPI):
await create_banchobot()
await download_service.start_health_check() # 启动下载服务健康检查
await start_cache_scheduler() # 启动缓存调度器
redis_message_system.start() # 启动 Redis 消息系统
load_achievements()
# on shutdown
yield
stop_scheduler()
redis_message_system.stop() # 停止 Redis 消息系统
await stop_cache_scheduler() # 停止缓存调度器
await download_service.stop_health_check() # 停止下载服务健康检查
await engine.dispose()