refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Annotated, overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database import ChatMessageDict
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import (
|
||||
@@ -16,7 +17,7 @@ from app.log import log
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
from app.utils import bg_tasks, safe_json_dumps
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -65,7 +66,7 @@ class ChatServer:
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
await self.leave_channel(user, db_channel)
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
@@ -80,7 +81,7 @@ class ChatServer:
|
||||
return
|
||||
client = client_
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
await client.send_text(safe_json_dumps(event))
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
users_in_channel = self.channels.get(channel_id, [])
|
||||
@@ -107,38 +108,38 @@ class ChatServer:
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
||||
async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: "
|
||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
||||
f"Sending message to channel {message['channel_id']}, message_id: "
|
||||
f"{message['message_id']}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
||||
logger.info(f"Sending bot command to user {message['sender_id']}")
|
||||
bg_tasks.add_task(self.send_event, message["sender_id"], event)
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
|
||||
bg_tasks.add_task(
|
||||
self.broadcast,
|
||||
message.channel_id,
|
||||
message["channel_id"],
|
||||
event,
|
||||
)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
if message["message_id"] and message["message_id"] > 0:
|
||||
await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
|
||||
await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
|
||||
logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
|
||||
else:
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
|
||||
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel):
|
||||
channel_id = channel.channel_id
|
||||
|
||||
not_joined = []
|
||||
@@ -151,22 +152,18 @@ class ChatServer:
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
||||
async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
|
||||
@@ -175,25 +172,21 @@ class ChatServer:
|
||||
if user_id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user_id)
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp: ChatChannelDict = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
||||
async def leave_channel(self, user: User, channel: ChatChannel) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
|
||||
@@ -203,18 +196,14 @@ class ChatServer:
|
||||
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
@@ -232,7 +221,7 @@ class ChatServer:
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
|
||||
await self.join_channel(user, db_channel, session)
|
||||
await self.join_channel(user, db_channel)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
@@ -248,7 +237,7 @@ class ChatServer:
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
await self.leave_channel(user, db_channel)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
@@ -336,6 +325,6 @@ async def chat_websocket(
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
await server.join_channel(user, db_channel)
|
||||
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
|
||||
Reference in New Issue
Block a user