refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.database.user import RANKING_INCLUDES, User, UserResp
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import utcnow
|
||||
|
||||
from ._base import DatabaseModel, included, ondemand
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
VARCHAR,
|
||||
BigInteger,
|
||||
@@ -22,6 +23,8 @@ from sqlmodel import (
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.router.notification.server import ChatServer
|
||||
# ChatChannel
|
||||
|
||||
|
||||
@@ -44,16 +47,168 @@ class ChannelType(str, Enum):
|
||||
TEAM = "TEAM"
|
||||
|
||||
|
||||
class ChatChannelBase(SQLModel):
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatChannelDict(TypedDict):
|
||||
channel_id: int
|
||||
description: str
|
||||
name: str
|
||||
icon: str | None
|
||||
type: ChannelType
|
||||
uuid: NotRequired[str | None]
|
||||
message_length_limit: NotRequired[int]
|
||||
moderated: NotRequired[bool]
|
||||
current_user_attributes: NotRequired[ChatUserAttributes]
|
||||
last_read_id: NotRequired[int | None]
|
||||
last_message_id: NotRequired[int | None]
|
||||
recent_messages: NotRequired[list["ChatMessageDict"]]
|
||||
users: NotRequired[list[int]]
|
||||
|
||||
|
||||
class ChatChannelModel(DatabaseModel[ChatChannelDict]):
|
||||
CONVERSATION_INCLUDES: ClassVar[list[str]] = [
|
||||
"last_message_id",
|
||||
"users",
|
||||
]
|
||||
LISTING_INCLUDES: ClassVar[list[str]] = [
|
||||
*CONVERSATION_INCLUDES,
|
||||
"current_user_attributes",
|
||||
"last_read_id",
|
||||
]
|
||||
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||
icon: str | None = Field(default=None)
|
||||
type: ChannelType = Field(index=True)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str:
|
||||
users = server.channels.get(channel.channel_id, [])
|
||||
if channel.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
return target_name.one()
|
||||
return channel.name
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
@included
|
||||
@staticmethod
|
||||
async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool:
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
return silence is not None
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_attributes(
|
||||
session: AsyncSession,
|
||||
channel: "ChatChannel",
|
||||
user: User,
|
||||
) -> ChatUserAttributes:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
can_message = silence is None
|
||||
can_message_error = "You are silenced in this channel" if not can_message else None
|
||||
|
||||
redis = get_redis()
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0)
|
||||
|
||||
return ChatUserAttributes(
|
||||
can_message=can_message,
|
||||
can_message_error=can_message_error,
|
||||
last_read_id=last_read_id,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
redis = get_redis()
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
redis = get_redis()
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def recent_messages(
|
||||
session: AsyncSession,
|
||||
channel: "ChatChannel",
|
||||
) -> list["ChatMessageDict"]:
|
||||
messages = (
|
||||
await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.channel_id == channel.channel_id)
|
||||
.order_by(col(ChatMessage.message_id).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
result = [
|
||||
await ChatMessageModel.transform(
|
||||
msg,
|
||||
)
|
||||
for msg in reversed(messages)
|
||||
]
|
||||
return result
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def users(
|
||||
_session: AsyncSession,
|
||||
channel: "ChatChannel",
|
||||
server: "ChatServer",
|
||||
user: User,
|
||||
) -> list[int]:
|
||||
if channel.type == ChannelType.PUBLIC:
|
||||
return []
|
||||
users = server.channels.get(channel.channel_id, []).copy()
|
||||
if channel.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
users = [target_user_id, user.id]
|
||||
return users
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int:
|
||||
return 1000
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelModel, table=True):
|
||||
__tablename__: str = "chat_channels"
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
|
||||
@classmethod
|
||||
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
||||
@@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True):
|
||||
return channel
|
||||
|
||||
|
||||
class ChatChannelResp(ChatChannelBase):
|
||||
channel_id: int
|
||||
moderated: bool = False
|
||||
uuid: str | None = None
|
||||
current_user_attributes: ChatUserAttributes | None = None
|
||||
last_read_id: int | None = None
|
||||
last_message_id: int | None = None
|
||||
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
|
||||
users: list[int] = Field(default_factory=list)
|
||||
message_length_limit: int = 1000
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
channel: ChatChannel,
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
redis: Redis,
|
||||
users: list[int] | None = None,
|
||||
include_recent_messages: bool = False,
|
||||
) -> Self:
|
||||
c = cls.model_validate(channel)
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
||||
|
||||
if silence is not None:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=False,
|
||||
can_message_error=silence.reason or "You are muted in this channel.",
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = True
|
||||
else:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=True,
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = False
|
||||
|
||||
c.current_user_attributes = attribute
|
||||
if c.type != ChannelType.PUBLIC and users is not None:
|
||||
c.users = users
|
||||
c.last_message_id = last_msg
|
||||
c.last_read_id = last_read_id
|
||||
|
||||
if include_recent_messages:
|
||||
messages = (
|
||||
await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.channel_id == channel.channel_id)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(10)
|
||||
)
|
||||
).all()
|
||||
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
|
||||
c.recent_messages.reverse()
|
||||
|
||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
c.name = target_name.one()
|
||||
c.users = [target_user_id, user.id]
|
||||
return c
|
||||
|
||||
|
||||
# ChatMessage
|
||||
class ChatMessageDict(TypedDict):
|
||||
channel_id: int
|
||||
content: str
|
||||
message_id: int
|
||||
sender_id: int
|
||||
timestamp: datetime
|
||||
type: MessageType
|
||||
uuid: str | None
|
||||
is_action: NotRequired[bool]
|
||||
sender: NotRequired[UserDict]
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
class ChatMessageModel(DatabaseModel[ChatMessageDict]):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int = Field(index=True, primary_key=True, default=None)
|
||||
@@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool:
|
||||
return db_message.type == MessageType.ACTION
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict:
|
||||
return await UserModel.transform(db_message.user)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageModel, table=True):
|
||||
__tablename__: str = "chat_messages"
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
channel: ChatChannel = Relationship()
|
||||
|
||||
|
||||
class ChatMessageResp(ChatMessageBase):
|
||||
sender: UserResp | None = None
|
||||
is_action: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
||||
) -> "ChatMessageResp":
|
||||
m = cls.model_validate(db_message.model_dump())
|
||||
m.is_action = db_message.type == MessageType.ACTION
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
|
||||
return m
|
||||
|
||||
|
||||
# SilenceUser
|
||||
channel: "ChatChannel" = Relationship()
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
|
||||
Reference in New Issue
Block a user