from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict from app.models.model import UTCBaseModel from app.utils import utcnow from ._base import DatabaseModel, included, ondemand from .user import User, UserDict, UserModel from pydantic import BaseModel from sqlmodel import ( VARCHAR, BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select, ) from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.router.notification.server import ChatServer # ChatChannel class ChatUserAttributes(BaseModel): can_message: bool can_message_error: str | None = None last_read_id: int class ChannelType(str, Enum): PUBLIC = "PUBLIC" PRIVATE = "PRIVATE" MULTIPLAYER = "MULTIPLAYER" SPECTATOR = "SPECTATOR" TEMPORARY = "TEMPORARY" PM = "PM" GROUP = "GROUP" SYSTEM = "SYSTEM" ANNOUNCE = "ANNOUNCE" TEAM = "TEAM" class MessageType(str, Enum): ACTION = "action" MARKDOWN = "markdown" PLAIN = "plain" class ChatChannelDict(TypedDict): channel_id: int description: str name: str icon: str | None type: ChannelType uuid: NotRequired[str | None] message_length_limit: NotRequired[int] moderated: NotRequired[bool] current_user_attributes: NotRequired[ChatUserAttributes] last_read_id: NotRequired[int | None] last_message_id: NotRequired[int | None] recent_messages: NotRequired[list["ChatMessageDict"]] users: NotRequired[list[int]] class ChatChannelModel(DatabaseModel[ChatChannelDict]): CONVERSATION_INCLUDES: ClassVar[list[str]] = [ "last_message_id", "users", ] LISTING_INCLUDES: ClassVar[list[str]] = [ *CONVERSATION_INCLUDES, "current_user_attributes", "last_read_id", ] channel_id: int = Field(primary_key=True, index=True, default=None) description: str = Field(sa_column=Column(VARCHAR(255), index=True)) icon: str | None = Field(default=None) type: ChannelType = Field(index=True) @included @staticmethod async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str: users = server.channels.get(channel.channel_id, []) if channel.type == ChannelType.PM and users and len(users) == 2: target_user_id = next(u for u in users if u != user.id) target_name = await session.exec(select(User.username).where(User.id == target_user_id)) return target_name.one() return channel.name @included @staticmethod async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool: silence = ( await session.exec( select(SilenceUser).where( SilenceUser.channel_id == channel.channel_id, SilenceUser.user_id == user.id, ) ) ).first() return silence is not None @ondemand @staticmethod async def current_user_attributes( session: AsyncSession, channel: "ChatChannel", user: User, ) -> ChatUserAttributes: from app.dependencies.database import get_redis silence = ( await session.exec( select(SilenceUser).where( SilenceUser.channel_id == channel.channel_id, SilenceUser.user_id == user.id, ) ) ).first() can_message = silence is None can_message_error = "You are silenced in this channel" if not can_message else None redis = get_redis() last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0) return ChatUserAttributes( can_message=can_message, can_message_error=can_message_error, last_read_id=last_read_id, ) @ondemand @staticmethod async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None: from app.dependencies.database import get_redis redis = get_redis() last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}") last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg @ondemand @staticmethod async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None: from app.dependencies.database import get_redis redis = get_redis() last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg") return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None @ondemand @staticmethod async def recent_messages( session: AsyncSession, channel: "ChatChannel", ) -> list["ChatMessageDict"]: messages = ( await session.exec( select(ChatMessage) .where(ChatMessage.channel_id == channel.channel_id) .order_by(col(ChatMessage.message_id).desc()) .limit(50) ) ).all() result = [ await ChatMessageModel.transform( msg, ) for msg in reversed(messages) ] return result @ondemand @staticmethod async def users( _session: AsyncSession, channel: "ChatChannel", server: "ChatServer", user: User, ) -> list[int]: if channel.type == ChannelType.PUBLIC: return [] users = server.channels.get(channel.channel_id, []).copy() if channel.type == ChannelType.PM and users and len(users) == 2: target_user_id = next(u for u in users if u != user.id) users = [target_user_id, user.id] return users @included @staticmethod async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int: return 1000 class ChatChannel(ChatChannelModel, table=True): __tablename__: str = "chat_channels" name: str = Field(sa_column=Column(VARCHAR(50), index=True)) @classmethod async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None": if isinstance(channel, int) or channel.isdigit(): # 使用查询而不是 get() 来确保对象完全加载 result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel))) channel_ = result.first() if channel_ is not None: return channel_ result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) return result.first() @classmethod async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None": channel = await cls.get(f"pm_{user1}_{user2}", session) if channel is None: channel = await cls.get(f"pm_{user2}_{user1}", session) return channel # ChatMessage class ChatMessageDict(TypedDict): channel_id: int content: str message_id: int sender_id: int timestamp: datetime type: MessageType uuid: str | None is_action: NotRequired[bool] sender: NotRequired[UserDict] class ChatMessageModel(DatabaseModel[ChatMessageDict]): channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id") content: str = Field(sa_column=Column(VARCHAR(1000))) message_id: int = Field(index=True, primary_key=True, default=None) sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default_factory=utcnow) type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True) uuid: str | None = Field(default=None) @included @staticmethod async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool: return db_message.type == MessageType.ACTION @ondemand @staticmethod async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict: return await UserModel.transform(db_message.user) class ChatMessage(ChatMessageModel, table=True): __tablename__: str = "chat_messages" user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) channel: "ChatChannel" = Relationship() class SilenceUser(UTCBaseModel, SQLModel, table=True): __tablename__: str = "chat_silence_users" id: int = Field(primary_key=True, default=None, index=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True) until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None) reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True)) banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default_factory=utcnow) class UserSilenceResp(SQLModel): id: int user_id: int @classmethod def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp": return cls( id=db_silence.id, user_id=db_silence.user_id, )