refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -8,14 +8,14 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
|
||||
from app.database.user import RANKING_INCLUDES, User, UserResp
|
||||
from app.database import ChatMessageDict
|
||||
from app.database.chat import ChatMessage, ChatMessageModel, MessageType
|
||||
from app.database.user import User, UserModel
|
||||
from app.dependencies.database import get_redis_message, with_db
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
from app.utils import bg_tasks, safe_json_dumps
|
||||
|
||||
|
||||
class RedisMessageSystem:
|
||||
@@ -35,7 +35,7 @@ class RedisMessageSystem:
|
||||
content: str,
|
||||
is_action: bool = False,
|
||||
user_uuid: str | None = None,
|
||||
) -> ChatMessageResp:
|
||||
) -> "ChatMessageDict":
|
||||
"""
|
||||
发送消息 - 立即存储到 Redis 并返回
|
||||
|
||||
@@ -47,7 +47,7 @@ class RedisMessageSystem:
|
||||
user_uuid: 用户UUID
|
||||
|
||||
Returns:
|
||||
ChatMessageResp: 消息响应对象
|
||||
ChatMessage: 消息响应对象
|
||||
"""
|
||||
# 生成消息ID和时间戳
|
||||
message_id = await self._generate_message_id(channel_id)
|
||||
@@ -57,28 +57,16 @@ class RedisMessageSystem:
|
||||
if not user.id:
|
||||
raise ValueError("User ID is required")
|
||||
|
||||
# 获取频道类型以判断是否需要存储到数据库
|
||||
async with with_db() as session:
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
|
||||
channel_type = channel_result.first()
|
||||
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
|
||||
|
||||
# 准备消息数据
|
||||
message_data = {
|
||||
message_data: "ChatMessageDict" = {
|
||||
"message_id": message_id,
|
||||
"channel_id": channel_id,
|
||||
"sender_id": user.id,
|
||||
"content": content,
|
||||
"timestamp": timestamp.isoformat(),
|
||||
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
|
||||
"timestamp": timestamp,
|
||||
"type": MessageType.ACTION if is_action else MessageType.PLAIN,
|
||||
"uuid": user_uuid or "",
|
||||
"status": "cached", # Redis 缓存状态
|
||||
"created_at": time.time(),
|
||||
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
|
||||
"is_action": is_action,
|
||||
}
|
||||
|
||||
# 立即存储到 Redis
|
||||
@@ -86,51 +74,13 @@ class RedisMessageSystem:
|
||||
|
||||
# 创建响应对象
|
||||
async with with_db() as session:
|
||||
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
|
||||
message_data["sender"] = user_resp
|
||||
|
||||
# 确保 statistics 不为空
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||
return message_data
|
||||
|
||||
user_resp.statistics = UserStatisticsResp(
|
||||
mode=user.playmode,
|
||||
global_rank=0,
|
||||
country_rank=0,
|
||||
pp=0.0,
|
||||
ranked_score=0,
|
||||
hit_accuracy=0.0,
|
||||
play_count=0,
|
||||
play_time=0,
|
||||
total_score=0,
|
||||
total_hits=0,
|
||||
maximum_combo=0,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
|
||||
level={"current": 1, "progress": 0},
|
||||
)
|
||||
|
||||
response = ChatMessageResp(
|
||||
message_id=message_id,
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
timestamp=timestamp,
|
||||
sender_id=user.id,
|
||||
sender=user_resp,
|
||||
is_action=is_action,
|
||||
uuid=user_uuid,
|
||||
)
|
||||
|
||||
if is_multiplayer:
|
||||
logger.info(
|
||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
|
||||
" will not be persisted to database"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||
return response
|
||||
|
||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
|
||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
|
||||
"""
|
||||
获取频道消息 - 优先从 Redis 获取最新消息
|
||||
|
||||
@@ -140,9 +90,9 @@ class RedisMessageSystem:
|
||||
since: 起始消息ID
|
||||
|
||||
Returns:
|
||||
List[ChatMessageResp]: 消息列表
|
||||
List[ChatMessageDict]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
messages: list["ChatMessageDict"] = []
|
||||
|
||||
try:
|
||||
# 从 Redis 获取最新消息
|
||||
@@ -154,45 +104,21 @@ class RedisMessageSystem:
|
||||
# 获取发送者信息
|
||||
sender = await session.get(User, msg_data["sender_id"])
|
||||
if sender:
|
||||
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
|
||||
user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)
|
||||
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
from app.database.chat import ChatMessageDict
|
||||
|
||||
user_resp.statistics = UserStatisticsResp(
|
||||
mode=sender.playmode,
|
||||
global_rank=0,
|
||||
country_rank=0,
|
||||
pp=0.0,
|
||||
ranked_score=0,
|
||||
hit_accuracy=0.0,
|
||||
play_count=0,
|
||||
play_time=0,
|
||||
total_score=0,
|
||||
total_hits=0,
|
||||
maximum_combo=0,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_counts={
|
||||
"ssh": 0,
|
||||
"ss": 0,
|
||||
"sh": 0,
|
||||
"s": 0,
|
||||
"a": 0,
|
||||
},
|
||||
level={"current": 1, "progress": 0},
|
||||
)
|
||||
|
||||
message_resp = ChatMessageResp(
|
||||
message_id=msg_data["message_id"],
|
||||
channel_id=msg_data["channel_id"],
|
||||
content=msg_data["content"],
|
||||
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
|
||||
sender_id=msg_data["sender_id"],
|
||||
sender=user_resp,
|
||||
is_action=msg_data["type"] == MessageType.ACTION.value,
|
||||
uuid=msg_data.get("uuid") or None,
|
||||
)
|
||||
message_resp: ChatMessageDict = {
|
||||
"message_id": msg_data["message_id"],
|
||||
"channel_id": msg_data["channel_id"],
|
||||
"content": msg_data["content"],
|
||||
"timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType]
|
||||
"sender_id": msg_data["sender_id"],
|
||||
"sender": user_resp,
|
||||
"is_action": msg_data["type"] == MessageType.ACTION.value,
|
||||
"uuid": msg_data.get("uuid") or None,
|
||||
"type": MessageType(msg_data["type"]),
|
||||
}
|
||||
messages.append(message_resp)
|
||||
|
||||
# 如果 Redis 消息不够,从数据库补充
|
||||
@@ -216,86 +142,46 @@ class RedisMessageSystem:
|
||||
|
||||
return message_id
|
||||
|
||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
|
||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
|
||||
"""存储消息到 Redis"""
|
||||
try:
|
||||
# 检查是否是多人房间消息
|
||||
is_multiplayer = message_data.get("is_multiplayer", False)
|
||||
|
||||
# 存储消息数据
|
||||
await self.redis.hset(
|
||||
# 存储消息数据为 JSON 字符串
|
||||
await self.redis.set(
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
|
||||
safe_json_dumps(message_data),
|
||||
ex=604800, # 7天过期
|
||||
)
|
||||
|
||||
# 设置消息过期时间(7天)
|
||||
await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
|
||||
|
||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
||||
# 添加到频道消息列表(按时间排序)
|
||||
channel_messages_key = f"channel:{channel_id}:messages"
|
||||
|
||||
# 更健壮的键类型检查和清理
|
||||
# 检查并清理错误类型的键
|
||||
try:
|
||||
key_type = await self.redis.type(channel_messages_key)
|
||||
if key_type == "none":
|
||||
# 键不存在,这是正常的
|
||||
pass
|
||||
elif key_type != "zset":
|
||||
# 键类型错误,需要清理
|
||||
if key_type not in ("none", "zset"):
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
await self.redis.delete(channel_messages_key)
|
||||
|
||||
# 验证删除是否成功
|
||||
verify_type = await self.redis.type(channel_messages_key)
|
||||
if verify_type != "none":
|
||||
logger.error(
|
||||
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
|
||||
)
|
||||
# 强制删除
|
||||
await self.redis.unlink(channel_messages_key)
|
||||
|
||||
except Exception as type_check_error:
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
# 如果检查失败,尝试强制删除键以确保清理
|
||||
try:
|
||||
await self.redis.delete(channel_messages_key)
|
||||
except Exception:
|
||||
# 最后的努力:使用unlink
|
||||
try:
|
||||
await self.redis.unlink(channel_messages_key)
|
||||
except Exception as final_error:
|
||||
logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
|
||||
await self.redis.delete(channel_messages_key)
|
||||
|
||||
# 添加到频道消息列表(sorted set)
|
||||
try:
|
||||
await self.redis.zadd(
|
||||
channel_messages_key,
|
||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
||||
)
|
||||
except Exception as zadd_error:
|
||||
logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
|
||||
# 如果添加失败,再次尝试清理并重试
|
||||
await self.redis.delete(channel_messages_key)
|
||||
await self.redis.zadd(
|
||||
channel_messages_key,
|
||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
||||
)
|
||||
await self.redis.zadd(
|
||||
channel_messages_key,
|
||||
mapping={f"msg:{channel_id}:{message_id}": message_id},
|
||||
)
|
||||
|
||||
# 保持频道消息列表大小(最多1000条)
|
||||
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
|
||||
|
||||
# 只有非多人房间消息才添加到待持久化队列
|
||||
if not is_multiplayer:
|
||||
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
|
||||
logger.debug(f"Message {message_id} added to persistence queue")
|
||||
else:
|
||||
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
||||
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
|
||||
logger.debug(f"Message {message_id} added to persistence queue")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store message to Redis: {e}")
|
||||
raise
|
||||
|
||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
|
||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
|
||||
"""从 Redis 获取消息"""
|
||||
try:
|
||||
# 获取消息键列表,按消息ID排序
|
||||
@@ -314,28 +200,16 @@ class RedisMessageSystem:
|
||||
|
||||
messages = []
|
||||
for key in message_keys:
|
||||
# 获取消息数据
|
||||
raw_data = await self.redis.hgetall(key)
|
||||
# 获取消息数据(JSON 字符串)
|
||||
raw_data = await self.redis.get(key)
|
||||
if raw_data:
|
||||
# 解码数据
|
||||
message_data: dict[str, Any] = {}
|
||||
for k, v in raw_data.items():
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
|
||||
message_data[k] = json.loads(v)
|
||||
elif k in ["message_id", "channel_id", "sender_id"]:
|
||||
message_data[k] = int(v)
|
||||
elif k == "is_multiplayer":
|
||||
message_data[k] = v == "True"
|
||||
elif k == "created_at":
|
||||
message_data[k] = float(v)
|
||||
else:
|
||||
message_data[k] = v
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
message_data[k] = v
|
||||
|
||||
messages.append(message_data)
|
||||
try:
|
||||
# 解析 JSON 字符串为字典
|
||||
message_data = json.loads(raw_data)
|
||||
messages.append(message_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode message JSON from {key}: {e}")
|
||||
continue
|
||||
|
||||
# 确保消息按ID正序排序(时间顺序)
|
||||
messages.sort(key=lambda x: x.get("message_id", 0))
|
||||
@@ -350,15 +224,15 @@ class RedisMessageSystem:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
return []
|
||||
|
||||
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
|
||||
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
|
||||
"""从数据库补充历史消息"""
|
||||
try:
|
||||
# 找到最小的消息ID
|
||||
min_id = float("inf")
|
||||
if existing_messages:
|
||||
for msg in existing_messages:
|
||||
if msg.message_id is not None and msg.message_id < min_id:
|
||||
min_id = msg.message_id
|
||||
if msg["message_id"] is not None and msg["message_id"] < min_id:
|
||||
min_id = msg["message_id"]
|
||||
|
||||
needed = limit - len(existing_messages)
|
||||
|
||||
@@ -378,13 +252,13 @@ class RedisMessageSystem:
|
||||
db_messages = (await session.exec(query)).all()
|
||||
|
||||
for msg in reversed(db_messages): # 按时间正序插入
|
||||
msg_resp = await ChatMessageResp.from_db(msg, session)
|
||||
msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
|
||||
existing_messages.insert(0, msg_resp)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill from database: {e}")
|
||||
|
||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
|
||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
|
||||
"""仅从数据库获取消息(回退方案)"""
|
||||
try:
|
||||
async with with_db() as session:
|
||||
@@ -402,7 +276,7 @@ class RedisMessageSystem:
|
||||
|
||||
messages = (await session.exec(query)).all()
|
||||
|
||||
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
results = await ChatMessageModel.transform_many(messages, includes=["sender"])
|
||||
|
||||
# 如果是 since > 0,保持正序;否则反转为时间正序
|
||||
if since == 0:
|
||||
@@ -450,27 +324,17 @@ class RedisMessageSystem:
|
||||
# 解析频道ID和消息ID
|
||||
channel_id, message_id = map(int, key.split(":"))
|
||||
|
||||
# 从 Redis 获取消息数据
|
||||
raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
|
||||
# 从 Redis 获取消息数据(JSON 字符串)
|
||||
raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")
|
||||
|
||||
if not raw_data:
|
||||
continue
|
||||
|
||||
# 解码数据
|
||||
message_data = {}
|
||||
for k, v in raw_data.items():
|
||||
message_data[k] = v
|
||||
|
||||
# 检查是否是多人房间消息,如果是则跳过数据库存储
|
||||
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
|
||||
if is_multiplayer:
|
||||
# 多人房间消息不存储到数据库,直接标记为已跳过
|
||||
await self.redis.hset(
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status",
|
||||
"skipped_multiplayer",
|
||||
)
|
||||
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
|
||||
# 解析 JSON 字符串为字典
|
||||
try:
|
||||
message_data = json.loads(raw_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
|
||||
continue
|
||||
|
||||
# 检查消息是否已存在于数据库
|
||||
@@ -491,13 +355,6 @@ class RedisMessageSystem:
|
||||
|
||||
session.add(db_message)
|
||||
|
||||
# 更新 Redis 中的状态
|
||||
await self.redis.hset(
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status",
|
||||
"persisted",
|
||||
)
|
||||
|
||||
logger.debug(f"Message {message_id} persisted to database")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user