refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -8,14 +8,14 @@
import asyncio
from datetime import datetime
import json
import time
from typing import Any
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.database import ChatMessageDict
from app.database.chat import ChatMessage, ChatMessageModel, MessageType
from app.database.user import User, UserModel
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
from app.utils import bg_tasks
from app.utils import bg_tasks, safe_json_dumps
class RedisMessageSystem:
@@ -35,7 +35,7 @@ class RedisMessageSystem:
content: str,
is_action: bool = False,
user_uuid: str | None = None,
) -> ChatMessageResp:
) -> "ChatMessageDict":
"""
发送消息 - 立即存储到 Redis 并返回
@@ -47,7 +47,7 @@ class RedisMessageSystem:
user_uuid: 用户UUID
Returns:
ChatMessageResp: 消息响应对象
ChatMessage: 消息响应对象
"""
# 生成消息ID和时间戳
message_id = await self._generate_message_id(channel_id)
@@ -57,28 +57,16 @@ class RedisMessageSystem:
if not user.id:
raise ValueError("User ID is required")
# 获取频道类型以判断是否需要存储到数据库
async with with_db() as session:
from app.database.chat import ChannelType, ChatChannel
from sqlmodel import select
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
channel_type = channel_result.first()
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
# 准备消息数据
message_data = {
message_data: "ChatMessageDict" = {
"message_id": message_id,
"channel_id": channel_id,
"sender_id": user.id,
"content": content,
"timestamp": timestamp.isoformat(),
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
"timestamp": timestamp,
"type": MessageType.ACTION if is_action else MessageType.PLAIN,
"uuid": user_uuid or "",
"status": "cached", # Redis 缓存状态
"created_at": time.time(),
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
"is_action": is_action,
}
# 立即存储到 Redis
@@ -86,51 +74,13 @@ class RedisMessageSystem:
# 创建响应对象
async with with_db() as session:
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
message_data["sender"] = user_resp
# 确保 statistics 不为空
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return message_data
user_resp.statistics = UserStatisticsResp(
mode=user.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
level={"current": 1, "progress": 0},
)
response = ChatMessageResp(
message_id=message_id,
channel_id=channel_id,
content=content,
timestamp=timestamp,
sender_id=user.id,
sender=user_resp,
is_action=is_action,
uuid=user_uuid,
)
if is_multiplayer:
logger.info(
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
" will not be persisted to database"
)
else:
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return response
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""
获取频道消息 - 优先从 Redis 获取最新消息
@@ -140,9 +90,9 @@ class RedisMessageSystem:
since: 起始消息ID
Returns:
List[ChatMessageResp]: 消息列表
List[ChatMessageDict]: 消息列表
"""
messages = []
messages: list["ChatMessageDict"] = []
try:
# 从 Redis 获取最新消息
@@ -154,45 +104,21 @@ class RedisMessageSystem:
# 获取发送者信息
sender = await session.get(User, msg_data["sender_id"])
if sender:
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
from app.database.chat import ChatMessageDict
user_resp.statistics = UserStatisticsResp(
mode=sender.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={
"ssh": 0,
"ss": 0,
"sh": 0,
"s": 0,
"a": 0,
},
level={"current": 1, "progress": 0},
)
message_resp = ChatMessageResp(
message_id=msg_data["message_id"],
channel_id=msg_data["channel_id"],
content=msg_data["content"],
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
sender_id=msg_data["sender_id"],
sender=user_resp,
is_action=msg_data["type"] == MessageType.ACTION.value,
uuid=msg_data.get("uuid") or None,
)
message_resp: ChatMessageDict = {
"message_id": msg_data["message_id"],
"channel_id": msg_data["channel_id"],
"content": msg_data["content"],
"timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType]
"sender_id": msg_data["sender_id"],
"sender": user_resp,
"is_action": msg_data["type"] == MessageType.ACTION.value,
"uuid": msg_data.get("uuid") or None,
"type": MessageType(msg_data["type"]),
}
messages.append(message_resp)
# 如果 Redis 消息不够,从数据库补充
@@ -216,86 +142,46 @@ class RedisMessageSystem:
return message_id
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
"""存储消息到 Redis"""
try:
# 检查是否是多人房间消息
is_multiplayer = message_data.get("is_multiplayer", False)
# 存储消息数据
await self.redis.hset(
# 存储消息数据为 JSON 字符串
await self.redis.set(
f"msg:{channel_id}:{message_id}",
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
safe_json_dumps(message_data),
ex=604800, # 7天过期
)
# 设置消息过期时间7天
await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
# 添加到频道消息列表(按时间排序
channel_messages_key = f"channel:{channel_id}:messages"
# 更健壮的键类型检查清理
# 检查清理错误类型的键
try:
key_type = await self.redis.type(channel_messages_key)
if key_type == "none":
# 键不存在,这是正常的
pass
elif key_type != "zset":
# 键类型错误,需要清理
if key_type not in ("none", "zset"):
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self.redis.delete(channel_messages_key)
# 验证删除是否成功
verify_type = await self.redis.type(channel_messages_key)
if verify_type != "none":
logger.error(
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
)
# 强制删除
await self.redis.unlink(channel_messages_key)
except Exception as type_check_error:
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
# 如果检查失败,尝试强制删除键以确保清理
try:
await self.redis.delete(channel_messages_key)
except Exception:
# 最后的努力使用unlink
try:
await self.redis.unlink(channel_messages_key)
except Exception as final_error:
logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
await self.redis.delete(channel_messages_key)
# 添加到频道消息列表sorted set
try:
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
except Exception as zadd_error:
logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
# 如果添加失败,再次尝试清理并重试
await self.redis.delete(channel_messages_key)
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
# 保持频道消息列表大小最多1000条
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
# 只有非多人房间消息才添加到待持久化队列
if not is_multiplayer:
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
logger.debug(f"Message {message_id} added to persistence queue")
else:
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
logger.debug(f"Message {message_id} added to persistence queue")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
raise
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""从 Redis 获取消息"""
try:
# 获取消息键列表按消息ID排序
@@ -314,28 +200,16 @@ class RedisMessageSystem:
messages = []
for key in message_keys:
# 获取消息数据
raw_data = await self.redis.hgetall(key)
# 获取消息数据JSON 字符串)
raw_data = await self.redis.get(key)
if raw_data:
# 解码数据
message_data: dict[str, Any] = {}
for k, v in raw_data.items():
# 尝试解析 JSON
try:
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
message_data[k] = json.loads(v)
elif k in ["message_id", "channel_id", "sender_id"]:
message_data[k] = int(v)
elif k == "is_multiplayer":
message_data[k] = v == "True"
elif k == "created_at":
message_data[k] = float(v)
else:
message_data[k] = v
except (json.JSONDecodeError, ValueError):
message_data[k] = v
messages.append(message_data)
try:
# 解析 JSON 字符串为字典
message_data = json.loads(raw_data)
messages.append(message_data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message JSON from {key}: {e}")
continue
# 确保消息按ID正序排序时间顺序
messages.sort(key=lambda x: x.get("message_id", 0))
@@ -350,15 +224,15 @@ class RedisMessageSystem:
logger.error(f"Failed to get messages from Redis: {e}")
return []
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
"""从数据库补充历史消息"""
try:
# 找到最小的消息ID
min_id = float("inf")
if existing_messages:
for msg in existing_messages:
if msg.message_id is not None and msg.message_id < min_id:
min_id = msg.message_id
if msg["message_id"] is not None and msg["message_id"] < min_id:
min_id = msg["message_id"]
needed = limit - len(existing_messages)
@@ -378,13 +252,13 @@ class RedisMessageSystem:
db_messages = (await session.exec(query)).all()
for msg in reversed(db_messages): # 按时间正序插入
msg_resp = await ChatMessageResp.from_db(msg, session)
msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
existing_messages.insert(0, msg_resp)
except Exception as e:
logger.error(f"Failed to backfill from database: {e}")
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
"""仅从数据库获取消息(回退方案)"""
try:
async with with_db() as session:
@@ -402,7 +276,7 @@ class RedisMessageSystem:
messages = (await session.exec(query)).all()
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
results = await ChatMessageModel.transform_many(messages, includes=["sender"])
# 如果是 since > 0保持正序否则反转为时间正序
if since == 0:
@@ -450,27 +324,17 @@ class RedisMessageSystem:
# 解析频道ID和消息ID
channel_id, message_id = map(int, key.split(":"))
# 从 Redis 获取消息数据
raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
# 从 Redis 获取消息数据JSON 字符串)
raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")
if not raw_data:
continue
# 解码数据
message_data = {}
for k, v in raw_data.items():
message_data[k] = v
# 检查是否是多人房间消息,如果是则跳过数据库存储
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
if is_multiplayer:
# 多人房间消息不存储到数据库,直接标记为已跳过
await self.redis.hset(
f"msg:{channel_id}:{message_id}",
"status",
"skipped_multiplayer",
)
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
# 解析 JSON 字符串为字典
try:
message_data = json.loads(raw_data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
continue
# 检查消息是否已存在于数据库
@@ -491,13 +355,6 @@ class RedisMessageSystem:
session.add(db_message)
# 更新 Redis 中的状态
await self.redis.hset(
f"msg:{channel_id}:{message_id}",
"status",
"persisted",
)
logger.debug(f"Message {message_id} persisted to database")
except Exception as e: