整理代码
This commit is contained in:
@@ -5,59 +5,66 @@
|
||||
- 支持消息状态同步和故障恢复
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from app.database.chat import ChatMessage, MessageType, ChatMessageResp
|
||||
from app.database.lazer_user import User, UserResp, RANKING_INCLUDES
|
||||
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.dependencies.database import get_redis_message, with_db
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class RedisMessageSystem:
|
||||
"""Redis 消息系统"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.redis = get_redis_message()
|
||||
self.executor = ThreadPoolExecutor(max_workers=2)
|
||||
self._batch_timer: Optional[asyncio.Task] = None
|
||||
self._batch_timer: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self.batch_interval = 5.0 # 5秒批量存储一次
|
||||
self.max_batch_size = 100 # 每批最多处理100条消息
|
||||
|
||||
|
||||
async def _redis_exec(self, func, *args, **kwargs):
|
||||
"""在线程池中执行 Redis 操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, lambda: func(*args, **kwargs))
|
||||
|
||||
async def send_message(self, channel_id: int, user: User, content: str,
|
||||
is_action: bool = False, user_uuid: Optional[str] = None) -> ChatMessageResp:
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
channel_id: int,
|
||||
user: User,
|
||||
content: str,
|
||||
is_action: bool = False,
|
||||
user_uuid: str | None = None,
|
||||
) -> ChatMessageResp:
|
||||
"""
|
||||
发送消息 - 立即存储到 Redis 并返回
|
||||
|
||||
|
||||
Args:
|
||||
channel_id: 频道ID
|
||||
user: 发送用户
|
||||
content: 消息内容
|
||||
is_action: 是否为动作消息
|
||||
user_uuid: 用户UUID
|
||||
|
||||
|
||||
Returns:
|
||||
ChatMessageResp: 消息响应对象
|
||||
"""
|
||||
# 生成消息ID和时间戳
|
||||
message_id = await self._generate_message_id(channel_id)
|
||||
timestamp = datetime.now()
|
||||
|
||||
|
||||
# 确保用户ID存在
|
||||
if not user.id:
|
||||
raise ValueError("User ID is required")
|
||||
|
||||
|
||||
# 准备消息数据
|
||||
message_data = {
|
||||
"message_id": message_id,
|
||||
@@ -68,19 +75,20 @@ class RedisMessageSystem:
|
||||
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
|
||||
"uuid": user_uuid or "",
|
||||
"status": "cached", # Redis 缓存状态
|
||||
"created_at": time.time()
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
|
||||
# 立即存储到 Redis
|
||||
await self._store_to_redis(message_id, channel_id, message_data)
|
||||
|
||||
|
||||
# 创建响应对象
|
||||
async with with_db() as session:
|
||||
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
|
||||
|
||||
# 确保 statistics 不为空
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
|
||||
user_resp.statistics = UserStatisticsResp(
|
||||
mode=user.playmode,
|
||||
global_rank=0,
|
||||
@@ -96,9 +104,9 @@ class RedisMessageSystem:
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
|
||||
level={"current": 1, "progress": 0}
|
||||
level={"current": 1, "progress": 0},
|
||||
)
|
||||
|
||||
|
||||
response = ChatMessageResp(
|
||||
message_id=message_id,
|
||||
channel_id=channel_id,
|
||||
@@ -107,51 +115,71 @@ class RedisMessageSystem:
|
||||
sender_id=user.id,
|
||||
sender=user_resp,
|
||||
is_action=is_action,
|
||||
uuid=user_uuid
|
||||
uuid=user_uuid,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||
return response
|
||||
|
||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> List[ChatMessageResp]:
|
||||
|
||||
async def get_messages(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[ChatMessageResp]:
|
||||
"""
|
||||
获取频道消息 - 优先从 Redis 获取最新消息
|
||||
|
||||
|
||||
Args:
|
||||
channel_id: 频道ID
|
||||
limit: 消息数量限制
|
||||
since: 起始消息ID
|
||||
|
||||
|
||||
Returns:
|
||||
List[ChatMessageResp]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
|
||||
try:
|
||||
# 从 Redis 获取最新消息
|
||||
redis_messages = await self._get_from_redis(channel_id, limit, since)
|
||||
|
||||
|
||||
# 为每条消息构建响应对象
|
||||
async with with_db() as session:
|
||||
for msg_data in redis_messages:
|
||||
# 获取发送者信息
|
||||
sender = await session.get(User, msg_data["sender_id"])
|
||||
if sender:
|
||||
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
|
||||
|
||||
user_resp = await UserResp.from_db(
|
||||
sender, session, RANKING_INCLUDES
|
||||
)
|
||||
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
|
||||
user_resp.statistics = UserStatisticsResp(
|
||||
mode=sender.playmode,
|
||||
global_rank=0, country_rank=0, pp=0.0,
|
||||
ranked_score=0, hit_accuracy=0.0, play_count=0,
|
||||
play_time=0, total_score=0, total_hits=0,
|
||||
maximum_combo=0, replays_watched_by_others=0,
|
||||
global_rank=0,
|
||||
country_rank=0,
|
||||
pp=0.0,
|
||||
ranked_score=0,
|
||||
hit_accuracy=0.0,
|
||||
play_count=0,
|
||||
play_time=0,
|
||||
total_score=0,
|
||||
total_hits=0,
|
||||
maximum_combo=0,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
|
||||
level={"current": 1, "progress": 0}
|
||||
grade_counts={
|
||||
"ssh": 0,
|
||||
"ss": 0,
|
||||
"sh": 0,
|
||||
"s": 0,
|
||||
"a": 0,
|
||||
},
|
||||
level={"current": 1, "progress": 0},
|
||||
)
|
||||
|
||||
|
||||
message_resp = ChatMessageResp(
|
||||
message_id=msg_data["message_id"],
|
||||
channel_id=msg_data["channel_id"],
|
||||
@@ -160,77 +188,97 @@ class RedisMessageSystem:
|
||||
sender_id=msg_data["sender_id"],
|
||||
sender=user_resp,
|
||||
is_action=msg_data["type"] == MessageType.ACTION.value,
|
||||
uuid=msg_data.get("uuid") or None
|
||||
uuid=msg_data.get("uuid") or None,
|
||||
)
|
||||
messages.append(message_resp)
|
||||
|
||||
|
||||
# 如果 Redis 消息不够,从数据库补充
|
||||
if len(messages) < limit and since == 0:
|
||||
await self._backfill_from_database(channel_id, messages, limit)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
# 回退到数据库查询
|
||||
messages = await self._get_from_database_only(channel_id, limit, since)
|
||||
|
||||
|
||||
return messages[:limit]
|
||||
|
||||
|
||||
async def _generate_message_id(self, channel_id: int) -> int:
|
||||
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
|
||||
# 使用全局计数器确保所有频道的消息ID都是严格递增的
|
||||
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
|
||||
|
||||
message_id = await self._redis_exec(
|
||||
self.redis.incr, "global_message_id_counter"
|
||||
)
|
||||
|
||||
# 同时更新频道的最后消息ID,用于客户端状态同步
|
||||
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis.set, f"channel:{channel_id}:last_msg_id", message_id
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: Dict[str, Any]):
|
||||
|
||||
async def _store_to_redis(
|
||||
self, message_id: int, channel_id: int, message_data: dict[str, Any]
|
||||
):
|
||||
"""存储消息到 Redis"""
|
||||
try:
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
|
||||
for k, v in message_data.items()}
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
|
||||
for k, v in message_data.items()
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 设置消息过期时间(7天)
|
||||
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis.expire, f"msg:{channel_id}:{message_id}", 604800
|
||||
)
|
||||
|
||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
||||
channel_messages_key = f"channel:{channel_id}:messages"
|
||||
|
||||
|
||||
# 检查键的类型,如果不是 zset 类型则删除
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
logger.warning(
|
||||
f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}"
|
||||
)
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
except Exception as type_check_error:
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
logger.warning(
|
||||
f"Failed to check key type for {channel_messages_key}: {type_check_error}"
|
||||
)
|
||||
# 如果检查失败,直接删除键以确保清理
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
|
||||
|
||||
# 添加到频道消息列表(sorted set)
|
||||
await self._redis_exec(
|
||||
self.redis.zadd,
|
||||
channel_messages_key,
|
||||
{f"msg:{channel_id}:{message_id}": message_id}
|
||||
self.redis.zadd,
|
||||
channel_messages_key,
|
||||
{f"msg:{channel_id}:{message_id}": message_id},
|
||||
)
|
||||
|
||||
|
||||
# 保持频道消息列表大小(最多1000条)
|
||||
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
|
||||
)
|
||||
|
||||
# 添加到待持久化队列
|
||||
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store message to Redis: {e}")
|
||||
raise
|
||||
|
||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> List[Dict[str, Any]]:
|
||||
|
||||
async def _get_from_redis(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
"""从 Redis 获取消息"""
|
||||
try:
|
||||
# 获取消息键列表,按消息ID排序
|
||||
@@ -239,22 +287,22 @@ class RedisMessageSystem:
|
||||
message_keys = await self._redis_exec(
|
||||
self.redis.zrangebyscore,
|
||||
f"channel:{channel_id}:messages",
|
||||
since + 1, "+inf",
|
||||
start=0, num=limit
|
||||
since + 1,
|
||||
"+inf",
|
||||
start=0,
|
||||
num=limit,
|
||||
)
|
||||
else:
|
||||
# 获取最新的消息(倒序获取,然后反转)
|
||||
message_keys = await self._redis_exec(
|
||||
self.redis.zrevrange,
|
||||
f"channel:{channel_id}:messages",
|
||||
0, limit - 1
|
||||
self.redis.zrevrange, f"channel:{channel_id}:messages", 0, limit - 1
|
||||
)
|
||||
|
||||
|
||||
messages = []
|
||||
for key in message_keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode('utf-8')
|
||||
|
||||
key = key.decode("utf-8")
|
||||
|
||||
# 获取消息数据
|
||||
raw_data = await self._redis_exec(self.redis.hgetall, key)
|
||||
if raw_data:
|
||||
@@ -262,106 +310,118 @@ class RedisMessageSystem:
|
||||
message_data = {}
|
||||
for k, v in raw_data.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode('utf-8')
|
||||
k = k.decode("utf-8")
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode('utf-8')
|
||||
|
||||
v = v.decode("utf-8")
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
if k in ['grade_counts', 'level'] or v.startswith(('{', '[')):
|
||||
if k in ["grade_counts", "level"] or v.startswith(
|
||||
("{", "[")
|
||||
):
|
||||
message_data[k] = json.loads(v)
|
||||
elif k in ['message_id', 'channel_id', 'sender_id']:
|
||||
elif k in ["message_id", "channel_id", "sender_id"]:
|
||||
message_data[k] = int(v)
|
||||
elif k == 'created_at':
|
||||
elif k == "created_at":
|
||||
message_data[k] = float(v)
|
||||
else:
|
||||
message_data[k] = v
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
message_data[k] = v
|
||||
|
||||
|
||||
messages.append(message_data)
|
||||
|
||||
|
||||
# 确保消息按ID正序排序(时间顺序)
|
||||
messages.sort(key=lambda x: x.get('message_id', 0))
|
||||
|
||||
messages.sort(key=lambda x: x.get("message_id", 0))
|
||||
|
||||
# 如果是获取最新消息(since=0),需要保持倒序(最新的在前面)
|
||||
if since == 0:
|
||||
messages.reverse()
|
||||
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
return []
|
||||
|
||||
async def _backfill_from_database(self, channel_id: int, existing_messages: List[ChatMessageResp], limit: int):
|
||||
|
||||
async def _backfill_from_database(
|
||||
self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int
|
||||
):
|
||||
"""从数据库补充历史消息"""
|
||||
try:
|
||||
# 找到最小的消息ID
|
||||
min_id = float('inf')
|
||||
min_id = float("inf")
|
||||
if existing_messages:
|
||||
for msg in existing_messages:
|
||||
if msg.message_id is not None and msg.message_id < min_id:
|
||||
min_id = msg.message_id
|
||||
|
||||
|
||||
needed = limit - len(existing_messages)
|
||||
|
||||
|
||||
if needed <= 0:
|
||||
return
|
||||
|
||||
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select, col
|
||||
query = select(ChatMessage).where(
|
||||
ChatMessage.channel_id == channel_id
|
||||
)
|
||||
|
||||
if min_id != float('inf'):
|
||||
from sqlmodel import col, select
|
||||
|
||||
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
|
||||
|
||||
if min_id != float("inf"):
|
||||
query = query.where(col(ChatMessage.message_id) < min_id)
|
||||
|
||||
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(needed)
|
||||
|
||||
|
||||
db_messages = (await session.exec(query)).all()
|
||||
|
||||
|
||||
for msg in reversed(db_messages): # 按时间正序插入
|
||||
msg_resp = await ChatMessageResp.from_db(msg, session)
|
||||
existing_messages.insert(0, msg_resp)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill from database: {e}")
|
||||
|
||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> List[ChatMessageResp]:
|
||||
|
||||
async def _get_from_database_only(
|
||||
self, channel_id: int, limit: int, since: int
|
||||
) -> list[ChatMessageResp]:
|
||||
"""仅从数据库获取消息(回退方案)"""
|
||||
try:
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select, col
|
||||
from sqlmodel import col, select
|
||||
|
||||
query = select(ChatMessage).where(ChatMessage.channel_id == channel_id)
|
||||
|
||||
|
||||
if since > 0:
|
||||
# 获取指定ID之后的消息,按ID正序
|
||||
query = query.where(col(ChatMessage.message_id) > since)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(
|
||||
limit
|
||||
)
|
||||
else:
|
||||
# 获取最新消息,按ID倒序(最新的在前面)
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(
|
||||
limit
|
||||
)
|
||||
|
||||
messages = (await session.exec(query)).all()
|
||||
|
||||
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
|
||||
|
||||
results = [
|
||||
await ChatMessageResp.from_db(msg, session) for msg in messages
|
||||
]
|
||||
|
||||
# 如果是 since > 0,保持正序;否则反转为时间正序
|
||||
if since == 0:
|
||||
results.reverse()
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages from database: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def _batch_persist_to_database(self):
|
||||
"""批量持久化消息到数据库"""
|
||||
logger.info("Starting batch persistence to database")
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 获取待处理的消息
|
||||
@@ -374,52 +434,52 @@ class RedisMessageSystem:
|
||||
# key 是 (queue_name, value) 的元组
|
||||
value = key[1]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8')
|
||||
value = value.decode("utf-8")
|
||||
message_keys.append(value)
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
if message_keys:
|
||||
await self._process_message_batch(message_keys)
|
||||
else:
|
||||
await asyncio.sleep(self.batch_interval)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch persistence: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
logger.info("Stopped batch persistence to database")
|
||||
|
||||
async def _process_message_batch(self, message_keys: List[str]):
|
||||
|
||||
async def _process_message_batch(self, message_keys: list[str]):
|
||||
"""处理消息批次"""
|
||||
async with with_db() as session:
|
||||
for key in message_keys:
|
||||
try:
|
||||
# 解析频道ID和消息ID
|
||||
channel_id, message_id = map(int, key.split(':'))
|
||||
|
||||
channel_id, message_id = map(int, key.split(":"))
|
||||
|
||||
# 从 Redis 获取消息数据
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis.hgetall, f"msg:{channel_id}:{message_id}"
|
||||
)
|
||||
|
||||
|
||||
if not raw_data:
|
||||
continue
|
||||
|
||||
|
||||
# 解码数据
|
||||
message_data = {}
|
||||
for k, v in raw_data.items():
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode('utf-8')
|
||||
k = k.decode("utf-8")
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode('utf-8')
|
||||
v = v.decode("utf-8")
|
||||
message_data[k] = v
|
||||
|
||||
|
||||
# 检查消息是否已存在于数据库
|
||||
existing = await session.get(ChatMessage, int(message_id))
|
||||
if existing:
|
||||
continue
|
||||
|
||||
|
||||
# 创建数据库消息 - 使用 Redis 生成的正数ID
|
||||
db_message = ChatMessage(
|
||||
message_id=int(message_id), # 使用 Redis 系统生成的正数ID
|
||||
@@ -428,31 +488,34 @@ class RedisMessageSystem:
|
||||
content=message_data["content"],
|
||||
timestamp=datetime.fromisoformat(message_data["timestamp"]),
|
||||
type=MessageType(message_data["type"]),
|
||||
uuid=message_data.get("uuid") or None
|
||||
uuid=message_data.get("uuid") or None,
|
||||
)
|
||||
|
||||
|
||||
session.add(db_message)
|
||||
|
||||
|
||||
# 更新 Redis 中的状态
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status", "persisted"
|
||||
"status",
|
||||
"persisted",
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"Message {message_id} persisted to database")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process message {key}: {e}")
|
||||
|
||||
|
||||
# 提交批次
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(f"Batch of {len(message_keys)} messages committed to database")
|
||||
logger.info(
|
||||
f"Batch of {len(message_keys)} messages committed to database"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to commit message batch: {e}")
|
||||
await session.rollback()
|
||||
|
||||
|
||||
def start(self):
|
||||
"""启动系统"""
|
||||
if not self._running:
|
||||
@@ -461,63 +524,71 @@ class RedisMessageSystem:
|
||||
# 启动时初始化消息ID计数器
|
||||
asyncio.create_task(self._initialize_message_counter())
|
||||
logger.info("Redis message system started")
|
||||
|
||||
|
||||
async def _initialize_message_counter(self):
|
||||
"""初始化全局消息ID计数器,确保从数据库最大ID开始"""
|
||||
try:
|
||||
# 清理可能存在的问题键
|
||||
await self._cleanup_redis_keys()
|
||||
|
||||
|
||||
async with with_db() as session:
|
||||
from sqlmodel import select, func
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
# 获取数据库中最大的消息ID
|
||||
result = await session.exec(
|
||||
select(func.max(ChatMessage.message_id))
|
||||
)
|
||||
result = await session.exec(select(func.max(ChatMessage.message_id)))
|
||||
max_id = result.one() or 0
|
||||
|
||||
|
||||
# 检查 Redis 中的计数器值
|
||||
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
|
||||
current_counter = await self._redis_exec(
|
||||
self.redis.get, "global_message_id_counter"
|
||||
)
|
||||
current_counter = int(current_counter) if current_counter else 0
|
||||
|
||||
|
||||
# 设置计数器为两者中的最大值
|
||||
initial_counter = max(max_id, current_counter)
|
||||
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
|
||||
|
||||
logger.info(f"Initialized global message ID counter to {initial_counter}")
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis.set, "global_message_id_counter", initial_counter
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initialized global message ID counter to {initial_counter}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize message counter: {e}")
|
||||
# 如果初始化失败,设置一个安全的起始值
|
||||
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis.setnx, "global_message_id_counter", 1000000
|
||||
)
|
||||
|
||||
async def _cleanup_redis_keys(self):
|
||||
"""清理可能存在问题的 Redis 键"""
|
||||
try:
|
||||
# 扫描所有 channel:*:messages 键并检查类型
|
||||
keys_pattern = "channel:*:messages"
|
||||
keys = await self._redis_exec(self.redis.keys, keys_pattern)
|
||||
|
||||
|
||||
for key in keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode('utf-8')
|
||||
|
||||
key = key.decode("utf-8")
|
||||
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
|
||||
logger.warning(
|
||||
f"Cleaning up Redis key {key} with wrong type: {key_type}"
|
||||
)
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
|
||||
# 强制删除问题键
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
|
||||
|
||||
logger.info("Redis keys cleanup completed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup Redis keys: {e}")
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""停止系统"""
|
||||
if self._running:
|
||||
@@ -526,10 +597,10 @@ class RedisMessageSystem:
|
||||
self._batch_timer.cancel()
|
||||
self._batch_timer = None
|
||||
logger.info("Redis message system stopped")
|
||||
|
||||
|
||||
def __del__(self):
|
||||
"""清理资源"""
|
||||
if hasattr(self, 'executor'):
|
||||
if hasattr(self, "executor"):
|
||||
self.executor.shutdown(wait=False)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user