refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -18,6 +18,7 @@ from app.database.chat import ChatMessage, ChatMessageResp, MessageType
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.dependencies.database import get_redis_message, with_db
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
|
||||
|
||||
class RedisMessageSystem:
|
||||
@@ -67,12 +68,11 @@ class RedisMessageSystem:
|
||||
|
||||
# 获取频道类型以判断是否需要存储到数据库
|
||||
async with with_db() as session:
|
||||
from app.database.chat import ChatChannel, ChannelType
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
channel_result = await session.exec(
|
||||
select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
|
||||
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
|
||||
channel_type = channel_result.first()
|
||||
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
|
||||
|
||||
@@ -132,17 +132,14 @@ class RedisMessageSystem:
|
||||
|
||||
if is_multiplayer:
|
||||
logger.info(
|
||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
|
||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
|
||||
" will not be persisted to database"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
||||
)
|
||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||
return response
|
||||
|
||||
async def get_messages(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[ChatMessageResp]:
|
||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
|
||||
"""
|
||||
获取频道消息 - 优先从 Redis 获取最新消息
|
||||
|
||||
@@ -166,9 +163,7 @@ class RedisMessageSystem:
|
||||
# 获取发送者信息
|
||||
sender = await session.get(User, msg_data["sender_id"])
|
||||
if sender:
|
||||
user_resp = await UserResp.from_db(
|
||||
sender, session, RANKING_INCLUDES
|
||||
)
|
||||
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
|
||||
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
@@ -223,39 +218,28 @@ class RedisMessageSystem:
|
||||
async def _generate_message_id(self, channel_id: int) -> int:
|
||||
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
|
||||
# 使用全局计数器确保所有频道的消息ID都是严格递增的
|
||||
message_id = await self._redis_exec(
|
||||
self.redis.incr, "global_message_id_counter"
|
||||
)
|
||||
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
|
||||
|
||||
# 同时更新频道的最后消息ID,用于客户端状态同步
|
||||
await self._redis_exec(
|
||||
self.redis.set, f"channel:{channel_id}:last_msg_id", message_id
|
||||
)
|
||||
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
|
||||
|
||||
return message_id
|
||||
|
||||
async def _store_to_redis(
|
||||
self, message_id: int, channel_id: int, message_data: dict[str, Any]
|
||||
):
|
||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
|
||||
"""存储消息到 Redis"""
|
||||
try:
|
||||
# 检查是否是多人房间消息
|
||||
is_multiplayer = message_data.get("is_multiplayer", False)
|
||||
|
||||
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
|
||||
for k, v in message_data.items()
|
||||
},
|
||||
mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()},
|
||||
)
|
||||
|
||||
# 设置消息过期时间(7天)
|
||||
await self._redis_exec(
|
||||
self.redis.expire, f"msg:{channel_id}:{message_id}", 604800
|
||||
)
|
||||
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
|
||||
|
||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
||||
channel_messages_key = f"channel:{channel_id}:messages"
|
||||
@@ -264,14 +248,10 @@ class RedisMessageSystem:
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(
|
||||
f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}"
|
||||
)
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
except Exception as type_check_error:
|
||||
logger.warning(
|
||||
f"Failed to check key type for {channel_messages_key}: {type_check_error}"
|
||||
)
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
# 如果检查失败,直接删除键以确保清理
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
|
||||
@@ -283,15 +263,11 @@ class RedisMessageSystem:
|
||||
)
|
||||
|
||||
# 保持频道消息列表大小(最多1000条)
|
||||
await self._redis_exec(
|
||||
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
|
||||
)
|
||||
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
|
||||
|
||||
# 只有非多人房间消息才添加到待持久化队列
|
||||
if not is_multiplayer:
|
||||
await self._redis_exec(
|
||||
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
||||
)
|
||||
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
|
||||
logger.debug(f"Message {message_id} added to persistence queue")
|
||||
else:
|
||||
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
||||
@@ -300,9 +276,7 @@ class RedisMessageSystem:
|
||||
logger.error(f"Failed to store message to Redis: {e}")
|
||||
raise
|
||||
|
||||
async def _get_from_redis(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
|
||||
"""从 Redis 获取消息"""
|
||||
try:
|
||||
# 获取消息键列表,按消息ID排序
|
||||
@@ -340,9 +314,7 @@ class RedisMessageSystem:
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
if k in ["grade_counts", "level"] or v.startswith(
|
||||
("{", "[")
|
||||
):
|
||||
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
|
||||
message_data[k] = json.loads(v)
|
||||
elif k in ["message_id", "channel_id", "sender_id"]:
|
||||
message_data[k] = int(v)
|
||||
@@ -368,9 +340,7 @@ class RedisMessageSystem:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
return []
|
||||
|
||||
async def _backfill_from_database(
|
||||
self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int
|
||||
):
|
||||
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
|
||||
"""从数据库补充历史消息"""
|
||||
try:
|
||||
# 找到最小的消息ID
|
||||
@@ -404,9 +374,7 @@ class RedisMessageSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill from database: {e}")
|
||||
|
||||
async def _get_from_database_only(
|
||||
self, channel_id: int, limit: int, since: int
|
||||
) -> list[ChatMessageResp]:
|
||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
|
||||
"""仅从数据库获取消息(回退方案)"""
|
||||
try:
|
||||
async with with_db() as session:
|
||||
@@ -417,20 +385,14 @@ class RedisMessageSystem:
|
||||
if since > 0:
|
||||
# 获取指定ID之后的消息,按ID正序
|
||||
query = query.where(col(ChatMessage.message_id) > since)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(
|
||||
limit
|
||||
)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
else:
|
||||
# 获取最新消息,按ID倒序(最新的在前面)
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(
|
||||
limit
|
||||
)
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
|
||||
messages = (await session.exec(query)).all()
|
||||
|
||||
results = [
|
||||
await ChatMessageResp.from_db(msg, session) for msg in messages
|
||||
]
|
||||
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
|
||||
# 如果是 since > 0,保持正序;否则反转为时间正序
|
||||
if since == 0:
|
||||
@@ -451,9 +413,7 @@ class RedisMessageSystem:
|
||||
# 获取待处理的消息
|
||||
message_keys = []
|
||||
for _ in range(self.max_batch_size):
|
||||
key = await self._redis_exec(
|
||||
self.redis.brpop, ["pending_messages"], timeout=1
|
||||
)
|
||||
key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1)
|
||||
if key:
|
||||
# key 是 (queue_name, value) 的元组
|
||||
value = key[1]
|
||||
@@ -483,9 +443,7 @@ class RedisMessageSystem:
|
||||
channel_id, message_id = map(int, key.split(":"))
|
||||
|
||||
# 从 Redis 获取消息数据
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis.hgetall, f"msg:{channel_id}:{message_id}"
|
||||
)
|
||||
raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}")
|
||||
|
||||
if not raw_data:
|
||||
continue
|
||||
@@ -546,9 +504,7 @@ class RedisMessageSystem:
|
||||
# 提交批次
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Batch of {len(message_keys)} messages committed to database"
|
||||
)
|
||||
logger.info(f"Batch of {len(message_keys)} messages committed to database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to commit message batch: {e}")
|
||||
await session.rollback()
|
||||
@@ -559,7 +515,7 @@ class RedisMessageSystem:
|
||||
self._running = True
|
||||
self._batch_timer = asyncio.create_task(self._batch_persist_to_database())
|
||||
# 启动时初始化消息ID计数器
|
||||
asyncio.create_task(self._initialize_message_counter())
|
||||
bg_tasks.add_task(self._initialize_message_counter)
|
||||
logger.info("Redis message system started")
|
||||
|
||||
async def _initialize_message_counter(self):
|
||||
@@ -576,27 +532,19 @@ class RedisMessageSystem:
|
||||
max_id = result.one() or 0
|
||||
|
||||
# 检查 Redis 中的计数器值
|
||||
current_counter = await self._redis_exec(
|
||||
self.redis.get, "global_message_id_counter"
|
||||
)
|
||||
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
|
||||
current_counter = int(current_counter) if current_counter else 0
|
||||
|
||||
# 设置计数器为两者中的最大值
|
||||
initial_counter = max(max_id, current_counter)
|
||||
await self._redis_exec(
|
||||
self.redis.set, "global_message_id_counter", initial_counter
|
||||
)
|
||||
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
|
||||
|
||||
logger.info(
|
||||
f"Initialized global message ID counter to {initial_counter}"
|
||||
)
|
||||
logger.info(f"Initialized global message ID counter to {initial_counter}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize message counter: {e}")
|
||||
# 如果初始化失败,设置一个安全的起始值
|
||||
await self._redis_exec(
|
||||
self.redis.setnx, "global_message_id_counter", 1000000
|
||||
)
|
||||
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
|
||||
|
||||
async def _cleanup_redis_keys(self):
|
||||
"""清理可能存在问题的 Redis 键"""
|
||||
@@ -612,9 +560,7 @@ class RedisMessageSystem:
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(
|
||||
f"Cleaning up Redis key {key} with wrong type: {key_type}"
|
||||
)
|
||||
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
|
||||
|
||||
Reference in New Issue
Block a user