add message redis
This commit is contained in:
@@ -59,12 +59,17 @@ class ChatChannel(ChatChannelBase, table=True):
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
channel_ = await session.get(ChatChannel, channel)
|
||||
# 使用查询而不是 get() 来确保对象完全加载
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
channel_ = result.first()
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
return (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
return result.first()
|
||||
|
||||
@classmethod
|
||||
async def get_pm_channel(
|
||||
@@ -235,6 +240,7 @@ class UserSilenceResp(SQLModel):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
|
||||
assert db_silence.id is not None
|
||||
return cls(
|
||||
id=db_silence.id,
|
||||
user_id=db_silence.user_id,
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.config import settings
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
import redis.asyncio as redis
|
||||
import redis as sync_redis
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -38,6 +39,9 @@ engine = create_async_engine(
|
||||
# Redis 连接
|
||||
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
|
||||
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
|
||||
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
|
||||
@@ -80,5 +84,10 @@ def get_redis():
|
||||
return redis_client
|
||||
|
||||
|
||||
def get_redis_message():
|
||||
"""获取消息专用的 Redis 客户端 (db1)"""
|
||||
return redis_message_client
|
||||
|
||||
|
||||
def get_redis_pubsub():
|
||||
return redis_client.pubsub()
|
||||
|
||||
@@ -151,7 +151,6 @@ class PlaylistItem(BaseModel):
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
for i, mod1 in enumerate(mods):
|
||||
@@ -168,7 +167,6 @@ class PlaylistItem(BaseModel):
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
allowed_acronyms = {mod["acronym"] for mod in self.allowed_mods}
|
||||
|
||||
@@ -213,8 +211,6 @@ class PlaylistItem(BaseModel):
|
||||
"""
|
||||
from typing import Literal, cast
|
||||
|
||||
API_MODS = self._get_api_mods()
|
||||
|
||||
ruleset_id = user.ruleset_id if user.ruleset_id is not None else self.ruleset_id
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_id)
|
||||
|
||||
@@ -386,7 +382,7 @@ class MultiplayerRoom(BaseModel):
|
||||
match_state=None,
|
||||
playlist=playlist,
|
||||
active_countdowns=[],
|
||||
channel_id=room.channel_id,
|
||||
channel_id=room.channel_id or 0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -53,16 +53,24 @@ async def get_update(
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
@@ -105,7 +113,19 @@ async def join_channel(
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -125,7 +145,19 @@ async def leave_channel(
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -152,15 +184,19 @@ async def get_channel_list(
|
||||
).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
assert channel.channel_id is not None
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
|
||||
assert channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, [])
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
@@ -185,14 +221,33 @@ async def get_channel(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id is not None
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
|
||||
users = []
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
if len(user_ids) != 2:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
for id_ in user_ids:
|
||||
@@ -210,8 +265,8 @@ async def get_channel(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(db_channel.channel_id, [])
|
||||
if db_channel.type != ChannelType.PUBLIC
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
)
|
||||
)
|
||||
@@ -270,7 +325,8 @@ async def create_channel(
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
channel = await ChatChannel.get(channel_name, session)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
|
||||
channel = result.first()
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
@@ -294,12 +350,16 @@ async def create_channel(
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
assert channel.channel_id
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel.channel_id, []),
|
||||
server.channels.get(channel_id, []),
|
||||
include_recent_messages=True,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
@@ -11,11 +16,14 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.database import Database, get_redis, get_redis_message
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||
from app.router.v2 import api_v2_router as router
|
||||
from app.service.optimized_message import optimized_message_service
|
||||
from app.service.redis_message_system import redis_message_system
|
||||
from app.log import logger
|
||||
|
||||
from .banchobot import bot
|
||||
from .server import server
|
||||
@@ -89,42 +97,73 @@ async def send_message(
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询来获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
assert db_channel.channel_id
|
||||
# 立即提取所有需要的属性,避免后续延迟加载
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
assert current_user.id
|
||||
msg = ChatMessage(
|
||||
channel_id=db_channel.channel_id,
|
||||
|
||||
# 使用 Redis 消息系统发送消息 - 立即返回
|
||||
resp = await redis_message_system.send_message(
|
||||
channel_id=channel_id,
|
||||
user=current_user,
|
||||
content=req.message,
|
||||
sender_id=current_user.id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
is_action=req.is_action,
|
||||
user_uuid=req.uuid
|
||||
)
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(db_channel)
|
||||
resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
|
||||
# 立即广播消息给所有客户端
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and db_channel.type == ChannelType.PUBLIC
|
||||
resp, is_bot_command and channel_type == ChannelType.PUBLIC
|
||||
)
|
||||
|
||||
# 处理机器人命令
|
||||
if is_bot_command:
|
||||
await bot.try_handle(current_user, db_channel, req.message, session)
|
||||
if db_channel.type == ChannelType.PM:
|
||||
user_ids = db_channel.name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage.init(
|
||||
msg, current_user, [int(u) for u in user_ids], db_channel.type
|
||||
|
||||
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
|
||||
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
|
||||
temp_msg = ChatMessage(
|
||||
message_id=resp.message_id, # 使用 Redis 系统生成的ID
|
||||
channel_id=channel_id,
|
||||
content=req.message,
|
||||
sender_id=current_user.id,
|
||||
type=MessageType.ACTION if req.is_action else MessageType.PLAIN,
|
||||
uuid=req.uuid,
|
||||
)
|
||||
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage.init(
|
||||
temp_msg, current_user, [int(u) for u in user_ids], channel_type
|
||||
)
|
||||
)
|
||||
)
|
||||
elif db_channel.type == ChannelType.TEAM:
|
||||
await server.new_private_notification(
|
||||
ChannelMessageTeam.init(msg, current_user)
|
||||
)
|
||||
elif channel_type == ChannelType.TEAM:
|
||||
await server.new_private_notification(
|
||||
ChannelMessageTeam.init(temp_msg, current_user)
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@@ -143,21 +182,46 @@ async def get_message(
|
||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
messages = await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
ChatMessage.channel_id == db_channel.channel_id,
|
||||
col(ChatMessage.message_id) > since,
|
||||
col(ChatMessage.message_id) < until if until is not None else True,
|
||||
)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
# 使用 Redis 消息系统获取消息
|
||||
try:
|
||||
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get messages from Redis system: {e}")
|
||||
# 回退到传统数据库查询
|
||||
pass
|
||||
|
||||
# 回退到数据库查询
|
||||
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
|
||||
if since > 0:
|
||||
query = query.where(col(ChatMessage.message_id) > since)
|
||||
if until is not None:
|
||||
query = query.where(col(ChatMessage.message_id) < until)
|
||||
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
messages = (await session.exec(query)).all()
|
||||
resp = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
resp.reverse()
|
||||
return resp
|
||||
|
||||
|
||||
@@ -174,12 +238,28 @@ async def mark_as_read(
|
||||
message: int = Path(..., description="消息 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
assert db_channel.channel_id
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id
|
||||
assert current_user.id
|
||||
await server.mark_as_read(db_channel.channel_id, current_user.id, message)
|
||||
await server.mark_as_read(channel_id, current_user.id, message)
|
||||
|
||||
|
||||
class PMReq(BaseModel):
|
||||
|
||||
@@ -59,9 +59,14 @@ class ChatServer:
|
||||
for channel_id, channel in self.channels.items():
|
||||
if user_id in channel:
|
||||
channel.remove(user_id)
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel:
|
||||
await self.leave_channel(user, channel, session)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
@@ -79,8 +84,11 @@ class ChatServer:
|
||||
await client.send_text(event.model_dump_json())
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
for user_id in self.channels.get(channel_id, []):
|
||||
users_in_channel = self.channels.get(channel_id, [])
|
||||
logger.info(f"Broadcasting to channel {channel_id}, users: {users_in_channel}")
|
||||
for user_id in users_in_channel:
|
||||
await self.send_event(user_id, event)
|
||||
logger.debug(f"Sent event to user {user_id} in channel {channel_id}")
|
||||
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
@@ -88,24 +96,35 @@ class ChatServer:
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
logger.info(f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}")
|
||||
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
)
|
||||
assert message.message_id
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
else:
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
@@ -206,27 +225,37 @@ class ChatServer:
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.join_channel(user, channel, session)
|
||||
await self.join_channel(user, db_channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
return
|
||||
|
||||
await self.leave_channel(user, channel, session)
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
@@ -309,7 +338,13 @@ async def chat_websocket(
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
channel = await ChatChannel.get(1, session)
|
||||
if channel is not None:
|
||||
await server.join_channel(user, channel, session)
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == 1)
|
||||
)
|
||||
).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
|
||||
217
app/service/message_queue.py
Normal file
217
app/service/message_queue.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Redis 消息队列服务
|
||||
用于实现实时消息推送和异步数据库持久化
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
import concurrent.futures
|
||||
|
||||
from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
"""Redis 消息队列服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.redis = get_redis()
|
||||
self._processing = False
|
||||
self._batch_size = 50 # 批量处理大小
|
||||
self._batch_timeout = 1.0 # 批量处理超时时间(秒)
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
async def _run_in_executor(self, func, *args):
|
||||
"""在线程池中运行同步 Redis 操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self._executor, func, *args)
|
||||
|
||||
async def start_processing(self):
|
||||
"""启动消息处理任务"""
|
||||
if not self._processing:
|
||||
self._processing = True
|
||||
asyncio.create_task(self._process_message_queue())
|
||||
logger.info("Message queue processing started")
|
||||
|
||||
async def stop_processing(self):
|
||||
"""停止消息处理"""
|
||||
self._processing = False
|
||||
logger.info("Message queue processing stopped")
|
||||
|
||||
async def enqueue_message(self, message_data: dict) -> str:
|
||||
"""
|
||||
将消息加入 Redis 队列(实时响应)
|
||||
|
||||
Args:
|
||||
message_data: 消息数据字典,包含所有必要的字段
|
||||
|
||||
Returns:
|
||||
消息的临时 UUID
|
||||
"""
|
||||
# 生成临时 UUID
|
||||
temp_uuid = str(uuid.uuid4())
|
||||
message_data["temp_uuid"] = temp_uuid
|
||||
message_data["timestamp"] = datetime.now().isoformat()
|
||||
message_data["status"] = "pending" # pending, processing, completed, failed
|
||||
|
||||
# 将消息存储到 Redis
|
||||
await self._run_in_executor(
|
||||
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
|
||||
)
|
||||
await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
|
||||
|
||||
# 加入处理队列
|
||||
await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
|
||||
|
||||
logger.info(f"Message enqueued with temp_uuid: {temp_uuid}")
|
||||
return temp_uuid
|
||||
|
||||
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
|
||||
"""获取消息状态"""
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
||||
if not message_data:
|
||||
return None
|
||||
|
||||
return message_data
|
||||
|
||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""
|
||||
从 Redis 获取缓存的消息
|
||||
|
||||
Args:
|
||||
channel_id: 频道 ID
|
||||
limit: 限制数量
|
||||
since: 获取自此消息 ID 之后的消息
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 从 Redis 获取频道最近的消息 UUID 列表
|
||||
message_uuids = await self._run_in_executor(
|
||||
self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
|
||||
)
|
||||
|
||||
messages = []
|
||||
for uuid_str in message_uuids:
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
|
||||
if message_data:
|
||||
# 检查是否满足 since 条件
|
||||
if since > 0 and "message_id" in message_data:
|
||||
if int(message_data["message_id"]) <= since:
|
||||
continue
|
||||
|
||||
messages.append(message_data)
|
||||
|
||||
return messages[::-1] # 返回时间顺序
|
||||
|
||||
async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
|
||||
"""将消息 UUID 缓存到频道消息列表"""
|
||||
# 添加到频道消息列表开头
|
||||
await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
|
||||
# 限制缓存大小
|
||||
await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
|
||||
# 设置过期时间(24小时)
|
||||
await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
|
||||
|
||||
async def _process_message_queue(self):
|
||||
"""异步处理消息队列,批量写入数据库"""
|
||||
while self._processing:
|
||||
try:
|
||||
# 批量获取消息
|
||||
message_uuids = []
|
||||
for _ in range(self._batch_size):
|
||||
result = await self._run_in_executor(
|
||||
lambda: self.redis.brpop(["message_queue"], timeout=1)
|
||||
)
|
||||
if result:
|
||||
message_uuids.append(result[1])
|
||||
else:
|
||||
break
|
||||
|
||||
if message_uuids:
|
||||
await self._process_message_batch(message_uuids)
|
||||
else:
|
||||
# 没有消息时短暂等待
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message queue: {e}")
|
||||
await asyncio.sleep(1) # 错误时等待1秒再重试
|
||||
|
||||
async def _process_message_batch(self, message_uuids: list[str]):
|
||||
"""批量处理消息写入数据库"""
|
||||
async with with_db() as session:
|
||||
messages_to_insert = []
|
||||
|
||||
for temp_uuid in message_uuids:
|
||||
try:
|
||||
# 获取消息数据
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
||||
if not message_data:
|
||||
continue
|
||||
|
||||
# 更新状态为处理中
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
|
||||
|
||||
# 创建数据库消息对象
|
||||
msg = ChatMessage(
|
||||
channel_id=int(message_data["channel_id"]),
|
||||
content=message_data["content"],
|
||||
sender_id=int(message_data["sender_id"]),
|
||||
type=MessageType(message_data["type"]),
|
||||
uuid=message_data.get("user_uuid") # 用户提供的 UUID(如果有)
|
||||
)
|
||||
|
||||
messages_to_insert.append((msg, temp_uuid))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing message {temp_uuid}: {e}")
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
||||
|
||||
if messages_to_insert:
|
||||
try:
|
||||
# 批量插入数据库
|
||||
for msg, temp_uuid in messages_to_insert:
|
||||
session.add(msg)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# 更新所有消息状态和真实 ID
|
||||
for msg, temp_uuid in messages_to_insert:
|
||||
await session.refresh(msg)
|
||||
await self._run_in_executor(
|
||||
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping={
|
||||
"status": "completed",
|
||||
"message_id": str(msg.message_id),
|
||||
"created_at": msg.timestamp.isoformat() if msg.timestamp else ""
|
||||
})
|
||||
)
|
||||
|
||||
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting messages to database: {e}")
|
||||
await session.rollback()
|
||||
|
||||
# 标记所有消息为失败
|
||||
for _, temp_uuid in messages_to_insert:
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
||||
|
||||
|
||||
# 全局消息队列实例
|
||||
message_queue = MessageQueue()
|
||||
|
||||
|
||||
async def start_message_queue():
|
||||
"""启动消息队列处理"""
|
||||
await message_queue.start_processing()
|
||||
|
||||
|
||||
async def stop_message_queue():
|
||||
"""停止消息队列处理"""
|
||||
await message_queue.stop_processing()
|
||||
282
app/service/message_queue_processor.py
Normal file
282
app/service/message_queue_processor.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
消息队列处理服务
|
||||
专门处理 Redis 消息队列的异步写入数据库
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.database.chat import ChatMessage, MessageType
|
||||
from app.dependencies.database import get_redis_message, with_db
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class MessageQueueProcessor:
|
||||
"""消息队列处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.redis_message = get_redis_message()
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
self._processing = False
|
||||
self._queue_task = None
|
||||
|
||||
async def _redis_exec(self, func, *args, **kwargs):
|
||||
"""在线程池中执行 Redis 操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
|
||||
|
||||
async def cache_message(self, channel_id: int, message_data: dict, temp_uuid: str):
|
||||
"""将消息缓存到 Redis"""
|
||||
try:
|
||||
# 存储消息数据
|
||||
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data)
|
||||
await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
|
||||
|
||||
# 加入频道消息列表
|
||||
await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid)
|
||||
await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条
|
||||
await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期
|
||||
|
||||
# 加入异步处理队列
|
||||
await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid)
|
||||
|
||||
logger.info(f"Message cached to Redis: {temp_uuid}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cache message to Redis: {e}")
|
||||
|
||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""从 Redis 获取缓存的消息"""
|
||||
try:
|
||||
message_uuids = await self._redis_exec(
|
||||
self.redis_message.lrange, f"channel:{channel_id}:messages", 0, limit - 1
|
||||
)
|
||||
|
||||
messages = []
|
||||
for temp_uuid in message_uuids:
|
||||
# 解码 UUID 如果它是字节类型
|
||||
if isinstance(temp_uuid, bytes):
|
||||
temp_uuid = temp_uuid.decode('utf-8')
|
||||
|
||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
||||
if raw_data:
|
||||
# 解码 Redis 返回的字节数据
|
||||
message_data = {
|
||||
k.decode('utf-8') if isinstance(k, bytes) else k:
|
||||
v.decode('utf-8') if isinstance(v, bytes) else v
|
||||
for k, v in raw_data.items()
|
||||
}
|
||||
|
||||
# 检查 since 条件
|
||||
if since > 0 and message_data.get("message_id"):
|
||||
if int(message_data["message_id"]) <= since:
|
||||
continue
|
||||
messages.append(message_data)
|
||||
|
||||
return messages[::-1] # 按时间顺序返回
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached messages: {e}")
|
||||
return []
|
||||
|
||||
async def update_message_status(self, temp_uuid: str, status: str, message_id: Optional[int] = None):
|
||||
"""更新消息状态"""
|
||||
try:
|
||||
update_data = {"status": status}
|
||||
if message_id:
|
||||
update_data["message_id"] = str(message_id)
|
||||
update_data["db_timestamp"] = datetime.now().isoformat()
|
||||
|
||||
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update message status: {e}")
|
||||
|
||||
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
|
||||
"""获取消息状态"""
|
||||
try:
|
||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
||||
if not raw_data:
|
||||
return None
|
||||
|
||||
# 解码 Redis 返回的字节数据
|
||||
return {
|
||||
k.decode('utf-8') if isinstance(k, bytes) else k:
|
||||
v.decode('utf-8') if isinstance(v, bytes) else v
|
||||
for k, v in raw_data.items()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get message status: {e}")
|
||||
return None
|
||||
|
||||
async def _process_message_queue(self):
|
||||
"""处理消息队列,异步写入数据库"""
|
||||
logger.info("Message queue processing started")
|
||||
|
||||
while self._processing:
|
||||
try:
|
||||
# 批量获取消息
|
||||
message_uuids = []
|
||||
for _ in range(20): # 批量处理20条消息
|
||||
result = await self._redis_exec(
|
||||
self.redis_message.brpop, ["message_write_queue"], timeout=1
|
||||
)
|
||||
if result:
|
||||
# result是 (queue_name, value) 的元组,需要解码
|
||||
uuid_value = result[1]
|
||||
if isinstance(uuid_value, bytes):
|
||||
uuid_value = uuid_value.decode('utf-8')
|
||||
message_uuids.append(uuid_value)
|
||||
else:
|
||||
break
|
||||
|
||||
if not message_uuids:
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
|
||||
# 批量写入数据库
|
||||
await self._process_message_batch(message_uuids)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message queue processing: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info("Message queue processing stopped")
|
||||
|
||||
async def _process_message_batch(self, message_uuids: list[str]):
|
||||
"""批量处理消息写入数据库"""
|
||||
async with with_db() as session:
|
||||
for temp_uuid in message_uuids:
|
||||
try:
|
||||
# 获取消息数据并解码
|
||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
||||
if not raw_data:
|
||||
continue
|
||||
|
||||
# 解码 Redis 返回的字节数据
|
||||
message_data = {
|
||||
k.decode('utf-8') if isinstance(k, bytes) else k:
|
||||
v.decode('utf-8') if isinstance(v, bytes) else v
|
||||
for k, v in raw_data.items()
|
||||
}
|
||||
|
||||
if message_data.get("status") != "pending":
|
||||
continue
|
||||
|
||||
# 更新状态为处理中
|
||||
await self.update_message_status(temp_uuid, "processing")
|
||||
|
||||
# 创建数据库消息
|
||||
msg = ChatMessage(
|
||||
channel_id=int(message_data["channel_id"]),
|
||||
content=message_data["content"],
|
||||
sender_id=int(message_data["sender_id"]),
|
||||
type=MessageType(message_data["type"]),
|
||||
uuid=message_data.get("user_uuid") or None,
|
||||
)
|
||||
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
|
||||
# 更新成功状态,包含临时消息ID映射
|
||||
assert msg.message_id is not None
|
||||
await self.update_message_status(temp_uuid, "completed", msg.message_id)
|
||||
|
||||
# 如果有临时消息ID,存储映射关系并通知客户端更新
|
||||
if message_data.get("temp_message_id"):
|
||||
temp_msg_id = int(message_data["temp_message_id"])
|
||||
await self._redis_exec(
|
||||
self.redis_message.set,
|
||||
f"temp_to_real:{temp_msg_id}",
|
||||
str(msg.message_id),
|
||||
ex=3600 # 1小时过期
|
||||
)
|
||||
|
||||
# 发送消息ID更新通知到频道
|
||||
channel_id = int(message_data["channel_id"])
|
||||
await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data)
|
||||
|
||||
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process message {temp_uuid}: {e}")
|
||||
await self.update_message_status(temp_uuid, "failed")
|
||||
|
||||
async def _notify_message_update(self, channel_id: int, temp_message_id: int, real_message_id: int, message_data: dict):
|
||||
"""通知客户端消息ID已更新"""
|
||||
try:
|
||||
# 这里我们需要通过 SignalR 发送消息更新通知
|
||||
# 但为了避免循环依赖,我们将通过 Redis 发布消息更新事件
|
||||
update_event = {
|
||||
"event": "chat.message.update",
|
||||
"data": {
|
||||
"channel_id": channel_id,
|
||||
"temp_message_id": temp_message_id,
|
||||
"real_message_id": real_message_id,
|
||||
"timestamp": message_data.get("timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
# 发布到 Redis 频道,让 SignalR 服务处理
|
||||
await self._redis_exec(
|
||||
self.redis_message.publish,
|
||||
f"chat_updates:{channel_id}",
|
||||
json.dumps(update_event)
|
||||
)
|
||||
|
||||
logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to notify message update: {e}")
|
||||
|
||||
def start_processing(self):
|
||||
"""启动消息队列处理"""
|
||||
if not self._processing:
|
||||
self._processing = True
|
||||
self._queue_task = asyncio.create_task(self._process_message_queue())
|
||||
logger.info("Message queue processor started")
|
||||
|
||||
def stop_processing(self):
|
||||
"""停止消息队列处理"""
|
||||
if self._processing:
|
||||
self._processing = False
|
||||
if self._queue_task:
|
||||
self._queue_task.cancel()
|
||||
self._queue_task = None
|
||||
logger.info("Message queue processor stopped")
|
||||
|
||||
def __del__(self):
|
||||
"""清理资源"""
|
||||
if hasattr(self, 'executor'):
|
||||
self.executor.shutdown(wait=False)
|
||||
|
||||
|
||||
# 全局消息队列处理器实例
|
||||
message_queue_processor = MessageQueueProcessor()
|
||||
|
||||
|
||||
def start_message_processing():
|
||||
"""启动消息队列处理"""
|
||||
message_queue_processor.start_processing()
|
||||
|
||||
|
||||
def stop_message_processing():
|
||||
"""停止消息队列处理"""
|
||||
message_queue_processor.stop_processing()
|
||||
|
||||
|
||||
async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid: str):
|
||||
"""将消息缓存到 Redis - 便捷接口"""
|
||||
await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)
|
||||
|
||||
|
||||
async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""从 Redis 获取缓存的消息 - 便捷接口"""
|
||||
return await message_queue_processor.get_cached_messages(channel_id, limit, since)
|
||||
|
||||
|
||||
async def get_message_status(temp_uuid: str) -> Optional[dict]:
|
||||
"""获取消息状态 - 便捷接口"""
|
||||
return await message_queue_processor.get_message_status(temp_uuid)
|
||||
150
app/service/optimized_message.py
Normal file
150
app/service/optimized_message.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
优化的消息服务
|
||||
结合 Redis 缓存和异步数据库写入实现实时消息传送
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.database.chat import ChatMessage, ChatChannel, MessageType, ChannelType, ChatMessageResp
|
||||
from app.database.lazer_user import User
|
||||
from app.router.notification.server import server
|
||||
from app.service.message_queue import message_queue
|
||||
from app.log import logger
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class OptimizedMessageService:
|
||||
"""优化的消息服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.message_queue = message_queue
|
||||
|
||||
async def send_message_fast(
|
||||
self,
|
||||
channel_id: int,
|
||||
channel_type: ChannelType,
|
||||
channel_name: str,
|
||||
content: str,
|
||||
sender: User,
|
||||
is_action: bool = False,
|
||||
user_uuid: Optional[str] = None,
|
||||
session: Optional[AsyncSession] = None
|
||||
) -> ChatMessageResp:
|
||||
"""
|
||||
快速发送消息(先缓存到 Redis,异步写入数据库)
|
||||
|
||||
Args:
|
||||
channel_id: 频道 ID
|
||||
channel_type: 频道类型
|
||||
channel_name: 频道名称
|
||||
content: 消息内容
|
||||
sender: 发送者
|
||||
is_action: 是否为动作消息
|
||||
user_uuid: 用户提供的 UUID
|
||||
session: 数据库会话(可选,用于一些验证)
|
||||
|
||||
Returns:
|
||||
消息响应对象
|
||||
"""
|
||||
assert sender.id is not None
|
||||
|
||||
# 准备消息数据
|
||||
message_data = {
|
||||
"channel_id": str(channel_id),
|
||||
"content": content,
|
||||
"sender_id": str(sender.id),
|
||||
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
|
||||
"user_uuid": user_uuid or "",
|
||||
"channel_type": channel_type.value,
|
||||
"channel_name": channel_name
|
||||
}
|
||||
|
||||
# 立即将消息加入 Redis 队列(实时响应)
|
||||
temp_uuid = await self.message_queue.enqueue_message(message_data)
|
||||
|
||||
# 缓存到频道消息列表
|
||||
await self.message_queue.cache_channel_message(channel_id, temp_uuid)
|
||||
|
||||
# 创建临时响应对象(简化版本,用于立即响应)
|
||||
from datetime import datetime
|
||||
from app.database.lazer_user import UserResp
|
||||
|
||||
# 创建基本的用户响应对象
|
||||
user_resp = UserResp(
|
||||
id=sender.id,
|
||||
username=sender.username,
|
||||
country_code=getattr(sender, 'country_code', 'XX'),
|
||||
# 基本字段,其他复杂字段可以后续异步加载
|
||||
)
|
||||
|
||||
temp_response = ChatMessageResp(
|
||||
message_id=0, # 临时 ID,等数据库写入后会更新
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
sender_id=sender.id,
|
||||
sender=user_resp,
|
||||
is_action=is_action,
|
||||
uuid=user_uuid
|
||||
)
|
||||
temp_response.temp_uuid = temp_uuid # 添加临时 UUID 用于后续更新
|
||||
|
||||
logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
|
||||
return temp_response
|
||||
|
||||
async def get_cached_messages(
|
||||
self,
|
||||
channel_id: int,
|
||||
limit: int = 50,
|
||||
since: int = 0
|
||||
) -> list[dict]:
|
||||
"""
|
||||
获取缓存的消息
|
||||
|
||||
Args:
|
||||
channel_id: 频道 ID
|
||||
limit: 限制数量
|
||||
since: 获取自此消息 ID 之后的消息
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return await self.message_queue.get_cached_messages(channel_id, limit, since)
|
||||
|
||||
async def get_message_status(self, temp_uuid: str) -> Optional[dict]:
|
||||
"""
|
||||
获取消息状态
|
||||
|
||||
Args:
|
||||
temp_uuid: 临时消息 UUID
|
||||
|
||||
Returns:
|
||||
消息状态信息
|
||||
"""
|
||||
return await self.message_queue.get_message_status(temp_uuid)
|
||||
|
||||
async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> Optional[dict]:
|
||||
"""
|
||||
等待消息持久化到数据库
|
||||
|
||||
Args:
|
||||
temp_uuid: 临时消息 UUID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
完成后的消息状态
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
for _ in range(timeout * 10): # 每100ms检查一次
|
||||
status = await self.get_message_status(temp_uuid)
|
||||
if status and status.get("status") in ["completed", "failed"]:
|
||||
return status
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# 全局优化消息服务实例
|
||||
optimized_message_service = OptimizedMessageService()
|
||||
537
app/service/redis_message_system.py
Normal file
537
app/service/redis_message_system.py
Normal file
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
基于 Redis 的实时消息系统
|
||||
- 消息立即存储到 Redis 并实时返回
|
||||
- 定时批量存储到数据库
|
||||
- 支持消息状态同步和故障恢复
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.database.chat import ChatMessage, MessageType, ChatMessageResp
|
||||
from app.database.lazer_user import User, UserResp, RANKING_INCLUDES
|
||||
from app.dependencies.database import get_redis_message, with_db
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RedisMessageSystem:
|
||||
"""Redis 消息系统"""
|
||||
|
||||
def __init__(self):
|
||||
self.redis = get_redis_message()
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
self._batch_timer: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
self.batch_interval = 5.0 # 5秒批量存储一次
|
||||
self.max_batch_size = 100 # 每批最多处理100条消息
|
||||
|
||||
async def _redis_exec(self, func, *args, **kwargs):
|
||||
"""在线程池中执行 Redis 操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
|
||||
|
||||
async def send_message(self, channel_id: int, user: User, content: str,
|
||||
is_action: bool = False, user_uuid: Optional[str] = None) -> ChatMessageResp:
|
||||
"""
|
||||
发送消息 - 立即存储到 Redis 并返回
|
||||
|
||||
Args:
|
||||
channel_id: 频道ID
|
||||
user: 发送用户
|
||||
content: 消息内容
|
||||
is_action: 是否为动作消息
|
||||
user_uuid: 用户UUID
|
||||
|
||||
Returns:
|
||||
ChatMessageResp: 消息响应对象
|
||||
"""
|
||||
# 生成消息ID和时间戳
|
||||
message_id = await self._generate_message_id(channel_id)
|
||||
timestamp = datetime.now()
|
||||
|
||||
# 确保用户ID存在
|
||||
if not user.id:
|
||||
raise ValueError("User ID is required")
|
||||
|
||||
# 准备消息数据
|
||||
message_data = {
|
||||
"message_id": message_id,
|
||||
"channel_id": channel_id,
|
||||
"sender_id": user.id,
|
||||
"content": content,
|
||||
"timestamp": timestamp.isoformat(),
|
||||
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
|
||||
"uuid": user_uuid or "",
|
||||
"status": "cached", # Redis 缓存状态
|
||||
"created_at": time.time()
|
||||
}
|
||||
|
||||
# 立即存储到 Redis
|
||||
await self._store_to_redis(message_id, channel_id, message_data)
|
||||
|
||||
# 创建响应对象
|
||||
async with with_db() as session:
|
||||
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
|
||||
# 确保 statistics 不为空
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
user_resp.statistics = UserStatisticsResp(
|
||||
mode=user.playmode,
|
||||
global_rank=0,
|
||||
country_rank=0,
|
||||
pp=0.0,
|
||||
ranked_score=0,
|
||||
hit_accuracy=0.0,
|
||||
play_count=0,
|
||||
play_time=0,
|
||||
total_score=0,
|
||||
total_hits=0,
|
||||
maximum_combo=0,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
|
||||
level={"current": 1, "progress": 0}
|
||||
)
|
||||
|
||||
response = ChatMessageResp(
|
||||
message_id=message_id,
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
timestamp=timestamp,
|
||||
sender_id=user.id,
|
||||
sender=user_resp,
|
||||
is_action=is_action,
|
||||
uuid=user_uuid
|
||||
)
|
||||
|
||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||
return response
|
||||
|
||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> List[ChatMessageResp]:
|
||||
"""
|
||||
获取频道消息 - 优先从 Redis 获取最新消息
|
||||
|
||||
Args:
|
||||
channel_id: 频道ID
|
||||
limit: 消息数量限制
|
||||
since: 起始消息ID
|
||||
|
||||
Returns:
|
||||
List[ChatMessageResp]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
try:
|
||||
# 从 Redis 获取最新消息
|
||||
redis_messages = await self._get_from_redis(channel_id, limit, since)
|
||||
|
||||
# 为每条消息构建响应对象
|
||||
async with with_db() as session:
|
||||
for msg_data in redis_messages:
|
||||
# 获取发送者信息
|
||||
sender = await session.get(User, msg_data["sender_id"])
|
||||
if sender:
|
||||
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
|
||||
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
user_resp.statistics = UserStatisticsResp(
|
||||
mode=sender.playmode,
|
||||
global_rank=0, country_rank=0, pp=0.0,
|
||||
ranked_score=0, hit_accuracy=0.0, play_count=0,
|
||||
play_time=0, total_score=0, total_hits=0,
|
||||
maximum_combo=0, replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
|
||||
level={"current": 1, "progress": 0}
|
||||
)
|
||||
|
||||
message_resp = ChatMessageResp(
|
||||
message_id=msg_data["message_id"],
|
||||
channel_id=msg_data["channel_id"],
|
||||
content=msg_data["content"],
|
||||
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
|
||||
sender_id=msg_data["sender_id"],
|
||||
sender=user_resp,
|
||||
is_action=msg_data["type"] == MessageType.ACTION.value,
|
||||
uuid=msg_data.get("uuid") or None
|
||||
)
|
||||
messages.append(message_resp)
|
||||
|
||||
# 如果 Redis 消息不够,从数据库补充
|
||||
if len(messages) < limit and since == 0:
|
||||
await self._backfill_from_database(channel_id, messages, limit)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
# 回退到数据库查询
|
||||
messages = await self._get_from_database_only(channel_id, limit, since)
|
||||
|
||||
return messages[:limit]
|
||||
|
||||
async def _generate_message_id(self, channel_id: int) -> int:
|
||||
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
|
||||
# 使用全局计数器确保所有频道的消息ID都是严格递增的
|
||||
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
|
||||
|
||||
# 同时更新频道的最后消息ID,用于客户端状态同步
|
||||
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
|
||||
|
||||
return message_id
|
||||
|
||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: Dict[str, Any]):
|
||||
"""存储消息到 Redis"""
|
||||
try:
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
|
||||
for k, v in message_data.items()}
|
||||
)
|
||||
|
||||
# 设置消息过期时间(7天)
|
||||
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
|
||||
|
||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
||||
channel_messages_key = f"channel:{channel_id}:messages"
|
||||
|
||||
# 检查键的类型,如果不是 zset 类型则删除
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
except Exception as type_check_error:
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
# 如果检查失败,直接删除键以确保清理
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
|
||||
# 添加到频道消息列表(sorted set)
|
||||
await self._redis_exec(
|
||||
self.redis.zadd,
|
||||
channel_messages_key,
|
||||
{f"msg:{channel_id}:{message_id}": message_id}
|
||||
)
|
||||
|
||||
# 保持频道消息列表大小(最多1000条)
|
||||
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
|
||||
|
||||
# 添加到待持久化队列
|
||||
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store message to Redis: {e}")
|
||||
raise
|
||||
|
||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]:
|
||||
"""从 Redis 获取消息"""
|
||||
try:
|
||||
# 获取消息键列表,按消息ID排序
|
||||
if since > 0:
|
||||
# 获取指定ID之后的消息(正序)
|
||||
message_keys = await self._redis_exec(
|
||||
self.redis.zrangebyscore,
|
||||
f"channel:{channel_id}:messages",
|
||||
since + 1, "+inf",
|
||||
start=0, num=limit
|
||||
)
|
||||
else:
|
||||
# 获取最新的消息(倒序获取,然后反转)
|
||||
message_keys = await self._redis_exec(
|
||||
self.redis.zrevrange,
|
||||
f"channel:{channel_id}:messages",
|
||||
0, limit - 1
|
||||
)
|
||||
|
||||
messages = []
|
||||
for key in message_keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode('utf-8')
|
||||
|
||||
# 获取消息数据
|
||||
raw_data = await self._redis_exec(self.redis.hgetall, key)
|
||||
if raw_data:
|
||||
# 解码数据
|
||||
message_data = {}
|
||||
for k, v in raw_data.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode('utf-8')
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode('utf-8')
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
if k in ['grade_counts', 'level'] or v.startswith(('{', '[')):
|
||||
message_data[k] = json.loads(v)
|
||||
elif k in ['message_id', 'channel_id', 'sender_id']:
|
||||
message_data[k] = int(v)
|
||||
elif k == 'created_at':
|
||||
message_data[k] = float(v)
|
||||
else:
|
||||
message_data[k] = v
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
message_data[k] = v
|
||||
|
||||
messages.append(message_data)
|
||||
|
||||
# 确保消息按ID正序排序(时间顺序)
|
||||
messages.sort(key=lambda x: x.get('message_id', 0))
|
||||
|
||||
# 如果是获取最新消息(since=0),需要保持倒序(最新的在前面)
|
||||
if since == 0:
|
||||
messages.reverse()
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
return []
|
||||
|
||||
async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int):
|
||||
"""从数据库补充历史消息"""
|
||||
try:
|
||||
# 找到最小的消息ID
|
||||
min_id = float('inf')
|
||||
if existing_messages:
|
||||
for msg in existing_messages:
|
||||
if msg.message_id is not None and msg.message_id < min_id:
|
||||
min_id = msg.message_id
|
||||
|
||||
needed = limit - len(existing_messages)
|
||||
|
||||
if needed <= 0:
|
||||
return
|
||||
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select, col
|
||||
query = select(ChatMessage).where(
|
||||
ChatMessage.channel_id == channel_id
|
||||
)
|
||||
|
||||
if min_id != float('inf'):
|
||||
query = query.where(col(ChatMessage.message_id) < min_id)
|
||||
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed)
|
||||
|
||||
db_messages = (await session.exec(query)).all()
|
||||
|
||||
for msg in reversed(db_messages): # 按时间正序插入
|
||||
msg_resp = await ChatMessageResp.from_db(msg, session)
|
||||
existing_messages.insert(0, msg_resp)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill from database: {e}")
|
||||
|
||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]:
|
||||
"""仅从数据库获取消息(回退方案)"""
|
||||
try:
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select, col
|
||||
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
|
||||
|
||||
if since > 0:
|
||||
# 获取指定ID之后的消息,按ID正序
|
||||
query = query.where(col(ChatMessage.message_id) > since)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
else:
|
||||
# 获取最新消息,按ID倒序(最新的在前面)
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
|
||||
messages = (await session.exec(query)).all()
|
||||
|
||||
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
|
||||
# 如果是 since > 0,保持正序;否则反转为时间正序
|
||||
if since == 0:
|
||||
results.reverse()
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages from database: {e}")
|
||||
return []
|
||||
|
||||
async def _batch_persist_to_database(self):
|
||||
"""批量持久化消息到数据库"""
|
||||
logger.info("Starting batch persistence to database")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 获取待处理的消息
|
||||
message_keys = []
|
||||
for _ in range(self.max_batch_size):
|
||||
key = await self._redis_exec(
|
||||
self.redis.brpop, ["pending_messages"], timeout=1
|
||||
)
|
||||
if key:
|
||||
# key 是 (queue_name, value) 的元组
|
||||
value = key[1]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8')
|
||||
message_keys.append(value)
|
||||
else:
|
||||
break
|
||||
|
||||
if message_keys:
|
||||
await self._process_message_batch(message_keys)
|
||||
else:
|
||||
await asyncio.sleep(self.batch_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch persistence: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info("Stopped batch persistence to database")
|
||||
|
||||
async def _process_message_batch(self, message_keys: List[str]):
|
||||
"""处理消息批次"""
|
||||
async with with_db() as session:
|
||||
for key in message_keys:
|
||||
try:
|
||||
# 解析频道ID和消息ID
|
||||
channel_id, message_id = map(int, key.split(':'))
|
||||
|
||||
# 从 Redis 获取消息数据
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis.hgetall, f"msg:{channel_id}:{message_id}"
|
||||
)
|
||||
|
||||
if not raw_data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
message_data = {}
|
||||
for k, v in raw_data.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode('utf-8')
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode('utf-8')
|
||||
message_data[k] = v
|
||||
|
||||
# 检查消息是否已存在于数据库
|
||||
existing = await session.get(ChatMessage, int(message_id))
|
||||
if existing:
|
||||
continue
|
||||
|
||||
# 创建数据库消息 - 使用 Redis 生成的正数ID
|
||||
db_message = ChatMessage(
|
||||
message_id=int(message_id), # 使用 Redis 系统生成的正数ID
|
||||
channel_id=int(message_data["channel_id"]),
|
||||
sender_id=int(message_data["sender_id"]),
|
||||
content=message_data["content"],
|
||||
timestamp=datetime.fromisoformat(message_data["timestamp"]),
|
||||
type=MessageType(message_data["type"]),
|
||||
uuid=message_data.get("uuid") or None
|
||||
)
|
||||
|
||||
session.add(db_message)
|
||||
|
||||
# 更新 Redis 中的状态
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status", "persisted"
|
||||
)
|
||||
|
||||
logger.debug(f"Message {message_id} persisted to database")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process message {key}: {e}")
|
||||
|
||||
# 提交批次
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(f"Batch of {len(message_keys)} messages committed to database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to commit message batch: {e}")
|
||||
await session.rollback()
|
||||
|
||||
def start(self):
|
||||
"""启动系统"""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
self._batch_timer = asyncio.create_task(self._batch_persist_to_database())
|
||||
# 启动时初始化消息ID计数器
|
||||
asyncio.create_task(self._initialize_message_counter())
|
||||
logger.info("Redis message system started")
|
||||
|
||||
async def _initialize_message_counter(self):
|
||||
"""初始化全局消息ID计数器,确保从数据库最大ID开始"""
|
||||
try:
|
||||
# 清理可能存在的问题键
|
||||
await self._cleanup_redis_keys()
|
||||
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select, func
|
||||
|
||||
# 获取数据库中最大的消息ID
|
||||
result = await session.exec(
|
||||
select(func.max(ChatMessage.message_id))
|
||||
)
|
||||
max_id = result.one() or 0
|
||||
|
||||
# 检查 Redis 中的计数器值
|
||||
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
|
||||
current_counter = int(current_counter) if current_counter else 0
|
||||
|
||||
# 设置计数器为两者中的最大值
|
||||
initial_counter = max(max_id, current_counter)
|
||||
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
|
||||
|
||||
logger.info(f"Initialized global message ID counter to {initial_counter}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize message counter: {e}")
|
||||
# 如果初始化失败,设置一个安全的起始值
|
||||
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
|
||||
|
||||
async def _cleanup_redis_keys(self):
|
||||
"""清理可能存在问题的 Redis 键"""
|
||||
try:
|
||||
# 扫描所有 channel:*:messages 键并检查类型
|
||||
keys_pattern = "channel:*:messages"
|
||||
keys = await self._redis_exec(self.redis.keys, keys_pattern)
|
||||
|
||||
for key in keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode('utf-8')
|
||||
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
|
||||
# 强制删除问题键
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
|
||||
logger.info("Redis keys cleanup completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup Redis keys: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""停止系统"""
|
||||
if self._running:
|
||||
self._running = False
|
||||
if self._batch_timer:
|
||||
self._batch_timer.cancel()
|
||||
self._batch_timer = None
|
||||
logger.info("Redis message system stopped")
|
||||
|
||||
def __del__(self):
|
||||
"""清理资源"""
|
||||
if hasattr(self, 'executor'):
|
||||
self.executor.shutdown(wait=False)
|
||||
|
||||
|
||||
# 全局消息系统实例
|
||||
redis_message_system = RedisMessageSystem()
|
||||
3
main.py
3
main.py
@@ -32,6 +32,7 @@ from app.service.init_geoip import init_geoip
|
||||
from app.service.load_achievements import load_achievements
|
||||
from app.service.osu_rx_statistics import create_rx_statistics
|
||||
from app.service.recalculate import recalculate
|
||||
from app.service.redis_message_system import redis_message_system
|
||||
|
||||
# 检查 New Relic 配置文件是否存在,如果存在则初始化 New Relic
|
||||
newrelic_config_path = os.path.join(os.path.dirname(__file__), "newrelic.ini")
|
||||
@@ -77,10 +78,12 @@ async def lifespan(app: FastAPI):
|
||||
await create_banchobot()
|
||||
await download_service.start_health_check() # 启动下载服务健康检查
|
||||
await start_cache_scheduler() # 启动缓存调度器
|
||||
redis_message_system.start() # 启动 Redis 消息系统
|
||||
load_achievements()
|
||||
# on shutdown
|
||||
yield
|
||||
stop_scheduler()
|
||||
redis_message_system.stop() # 停止 Redis 消息系统
|
||||
await stop_cache_scheduler() # 停止缓存调度器
|
||||
await download_service.stop_health_check() # 停止下载服务健康检查
|
||||
await engine.dispose()
|
||||
|
||||
Reference in New Issue
Block a user