feat(chat): support public channel chat
This commit is contained in:
193
app/database/chat.py
Normal file
193
app/database/chat.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
VARCHAR,
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# ChatChannel
|
||||
|
||||
|
||||
class ChatUserAttributes(BaseModel):
|
||||
can_message: bool
|
||||
can_message_error: str | None = None
|
||||
last_read_id: int
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
PUBLIC = "PUBLIC"
|
||||
PRIVATE = "PRIVATE"
|
||||
MULTIPLAYER = "MULTIPLAYER"
|
||||
SPECTATOR = "SPECTATOR"
|
||||
TEMPORARY = "TEMPORARY"
|
||||
PM = "PM"
|
||||
GROUP = "GROUP"
|
||||
SYSTEM = "SYSTEM"
|
||||
ANNOUNCE = "ANNOUNCE"
|
||||
TEAM = "TEAM"
|
||||
|
||||
|
||||
class ChatChannelBase(SQLModel):
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||
icon: str | None = Field(default=None)
|
||||
type: ChannelType = Field(index=True)
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
channel_ = await session.get(ChatChannel, channel)
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
return (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
|
||||
|
||||
class ChatChannelResp(ChatChannelBase):
|
||||
channel_id: int
|
||||
moderated: bool = False
|
||||
uuid: str | None = None
|
||||
current_user_attributes: ChatUserAttributes | None = None
|
||||
last_read_id: int | None = None
|
||||
last_message_id: int | None = None
|
||||
recent_messages: list[str] | None = None
|
||||
users: list[int] | None = None
|
||||
message_length_limit: int = 1000
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
channel: ChatChannel,
|
||||
session: AsyncSession,
|
||||
users: list[int],
|
||||
user: User,
|
||||
redis: Redis,
|
||||
) -> Self:
|
||||
c = cls.model_validate(channel)
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
if last_msg and last_msg.isdigit():
|
||||
last_msg = int(last_msg)
|
||||
else:
|
||||
last_msg = None
|
||||
|
||||
last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
if last_read_id and last_read_id.isdigit():
|
||||
last_read_id = int(last_read_id)
|
||||
else:
|
||||
last_read_id = last_msg
|
||||
|
||||
if silence is not None:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=False,
|
||||
can_message_error=silence.reason or "You are muted in this channel.",
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = True
|
||||
else:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=True,
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = False
|
||||
|
||||
c.current_user_attributes = attribute
|
||||
c.users = users
|
||||
c.last_message_id = last_msg
|
||||
c.last_read_id = last_read_id
|
||||
return c
|
||||
|
||||
|
||||
# ChatMessage
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
|
||||
class ChatMessageResp(ChatMessageBase):
|
||||
sender: UserResp | None = None
|
||||
is_action: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
||||
) -> "ChatMessageResp":
|
||||
m = cls.model_validate(db_message.model_dump())
|
||||
m.is_action = db_message.type == MessageType.ACTION
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(
|
||||
db_message.user, session, RANKING_INCLUDES
|
||||
)
|
||||
return m
|
||||
|
||||
|
||||
# SilenceUser
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||
banned_at: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
Reference in New Issue
Block a user