refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel):
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
__tablename__: str = "chat_channels"
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
# 使用查询而不是 get() 来确保对象完全加载
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))
|
||||
channel_ = result.first()
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
return result.first()
|
||||
|
||||
@classmethod
|
||||
async def get_pm_channel(
|
||||
cls, user1: int, user2: int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None":
|
||||
channel = await cls.get(f"pm_{user1}_{user2}", session)
|
||||
if channel is None:
|
||||
channel = await cls.get(f"pm_{user2}_{user1}", session)
|
||||
@@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase):
|
||||
.limit(10)
|
||||
)
|
||||
).all()
|
||||
c.recent_messages = [
|
||||
await ChatMessageResp.from_db(msg, session, user) for msg in messages
|
||||
]
|
||||
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
|
||||
c.recent_messages.reverse()
|
||||
|
||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(
|
||||
select(User.username).where(User.id == target_user_id)
|
||||
)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
c.name = target_name.one()
|
||||
assert user.id
|
||||
c.users = [target_user_id, user.id]
|
||||
return c
|
||||
|
||||
@@ -181,19 +168,15 @@ class MessageType(str, Enum):
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
message_id: int = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "chat_messages"
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
channel: ChatChannel = Relationship()
|
||||
|
||||
@@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase):
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(
|
||||
db_message.user, session, RANKING_INCLUDES
|
||||
)
|
||||
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
|
||||
return m
|
||||
|
||||
|
||||
@@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase):
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
__tablename__: str = "chat_silence_users"
|
||||
id: int = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||
banned_at: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
|
||||
|
||||
|
||||
class UserSilenceResp(SQLModel):
|
||||
@@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
|
||||
assert db_silence.id is not None
|
||||
return cls(
|
||||
id=db_silence.id,
|
||||
user_id=db_silence.user_id,
|
||||
|
||||
Reference in New Issue
Block a user