refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable
|
||||
from math import ceil
|
||||
import random
|
||||
import shlex
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.calculator import calculate_weighted_pp
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType
|
||||
from app.database.score import Score, get_best_id
|
||||
from app.database.statistics import UserStatistics, get_rank
|
||||
from app.database.user import User
|
||||
@@ -95,7 +98,7 @@ class Bot:
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
await session.refresh(bot)
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
resp = await ChatMessageModel.transform(msg, includes=["sender"])
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
||||
@@ -119,7 +122,7 @@ class Bot:
|
||||
await session.refresh(channel)
|
||||
await session.refresh(user)
|
||||
await session.refresh(bot)
|
||||
await server.batch_join_channel([user, bot], channel, session)
|
||||
await server.batch_join_channel([user, bot], channel)
|
||||
return channel
|
||||
|
||||
async def _send_reply(
|
||||
|
||||
@@ -1,37 +1,40 @@
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatChannelModel,
|
||||
ChatMessage,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.user import User, UserResp
|
||||
from app.database.user import User, UserModel
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
from app.utils import api_doc
|
||||
|
||||
from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
class UpdateResponse(BaseModel):
|
||||
presence: list[ChatChannelResp] = Field(default_factory=list)
|
||||
silences: list[Any] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/updates",
|
||||
response_model=UpdateResponse,
|
||||
name="获取更新",
|
||||
description="获取当前用户所在频道的最新的禁言情况。",
|
||||
tags=["聊天"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"获取更新响应。",
|
||||
{"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
|
||||
ChatChannel.LISTING_INCLUDES,
|
||||
name="UpdateResponse",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
@@ -44,45 +47,44 @@ async def get_update(
|
||||
Query(alias="includes[]", description="要包含的更新类型"),
|
||||
] = ["presence", "silences"],
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
resp = {
|
||||
"presence": [],
|
||||
"silences": [],
|
||||
}
|
||||
if "presence" in includes:
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
|
||||
resp.presence.append(
|
||||
await ChatChannelResp.from_db(
|
||||
resp["presence"].append(
|
||||
await ChatChannelModel.transform(
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
user=current_user,
|
||||
server=server,
|
||||
includes=ChatChannel.LISTING_INCLUDES,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
||||
).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
return resp
|
||||
|
||||
|
||||
@router.put(
|
||||
"/chat/channels/{channel}/users/{user}",
|
||||
response_model=ChatChannelResp,
|
||||
name="加入频道",
|
||||
description="加入指定的公开/房间频道。",
|
||||
tags=["聊天"],
|
||||
responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
|
||||
)
|
||||
async def join_channel(
|
||||
session: Database,
|
||||
@@ -101,7 +103,7 @@ async def join_channel(
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return await server.join_channel(current_user, db_channel, session)
|
||||
return await server.join_channel(current_user, db_channel)
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -128,13 +130,13 @@ async def leave_channel(
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
await server.leave_channel(current_user, db_channel, session)
|
||||
await server.leave_channel(current_user, db_channel)
|
||||
return
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels",
|
||||
response_model=list[ChatChannelResp],
|
||||
responses={200: api_doc("加入的频道", list[ChatChannelModel])},
|
||||
name="获取频道列表",
|
||||
description="获取所有公开频道。",
|
||||
tags=["聊天"],
|
||||
@@ -142,35 +144,30 @@ async def leave_channel(
|
||||
async def get_channel_list(
|
||||
session: Database,
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
):
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
results = await ChatChannelModel.transform_many(
|
||||
channels,
|
||||
user=current_user,
|
||||
server=server,
|
||||
)
|
||||
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class GetChannelResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
users: list[UserResp] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}",
|
||||
response_model=GetChannelResp,
|
||||
responses={
|
||||
200: api_doc(
|
||||
"频道详细信息",
|
||||
{
|
||||
"channel": ChatChannelModel,
|
||||
"users": list[UserModel],
|
||||
},
|
||||
ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
|
||||
name="GetChannelResponse",
|
||||
)
|
||||
},
|
||||
name="获取频道信息",
|
||||
description="获取指定频道的信息。",
|
||||
tags=["聊天"],
|
||||
@@ -191,7 +188,6 @@ async def get_channel(
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
@@ -209,15 +205,15 @@ async def get_channel(
|
||||
users.extend([target_user, current_user])
|
||||
break
|
||||
|
||||
return GetChannelResp(
|
||||
channel=await ChatChannelResp.from_db(
|
||||
return {
|
||||
"channel": await ChatChannelModel.transform(
|
||||
db_channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
user=current_user,
|
||||
server=server,
|
||||
includes=ChatChannel.LISTING_INCLUDES,
|
||||
),
|
||||
"users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
|
||||
}
|
||||
|
||||
|
||||
class CreateChannelReq(BaseModel):
|
||||
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
|
||||
|
||||
@router.post(
|
||||
"/chat/channels",
|
||||
response_model=ChatChannelResp,
|
||||
responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
|
||||
name="创建频道",
|
||||
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
|
||||
tags=["聊天"],
|
||||
@@ -289,21 +285,13 @@ async def create_channel(
|
||||
await session.refresh(current_user)
|
||||
if req.type == "PM":
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
await server.batch_join_channel([*target_users, current_user], channel)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
await server.join_channel(current_user, channel)
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, []),
|
||||
include_recent_messages=True,
|
||||
return await ChatChannelModel.transform(
|
||||
channel, user=current_user, server=server, includes=["recent_messages.sender"]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database import ChatChannelModel
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatMessage,
|
||||
ChatMessageModel,
|
||||
MessageType,
|
||||
SilenceUser,
|
||||
UserSilenceResp,
|
||||
@@ -18,6 +18,7 @@ from app.log import log
|
||||
from app.models.notification import ChannelMessage, ChannelMessageTeam
|
||||
from app.router.v2 import api_v2_router as router
|
||||
from app.service.redis_message_system import redis_message_system
|
||||
from app.utils import api_doc
|
||||
|
||||
from .banchobot import bot
|
||||
from .server import server
|
||||
@@ -68,7 +69,7 @@ class MessageReq(BaseModel):
|
||||
|
||||
@router.post(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=ChatMessageResp,
|
||||
responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])},
|
||||
name="发送消息",
|
||||
description="发送消息到指定频道。",
|
||||
tags=["聊天"],
|
||||
@@ -130,7 +131,7 @@ async def send_message(
|
||||
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
|
||||
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
|
||||
temp_msg = ChatMessage(
|
||||
message_id=resp.message_id, # 使用 Redis 系统生成的ID
|
||||
message_id=resp["message_id"], # 使用 Redis 系统生成的ID
|
||||
channel_id=channel_id,
|
||||
content=req.message,
|
||||
sender_id=user_id,
|
||||
@@ -151,7 +152,7 @@ async def send_message(
|
||||
|
||||
@router.get(
|
||||
"/chat/channels/{channel}/messages",
|
||||
response_model=list[ChatMessageResp],
|
||||
responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])},
|
||||
name="获取消息",
|
||||
description="获取指定频道的消息列表(统一按时间正序返回)。",
|
||||
tags=["聊天"],
|
||||
@@ -177,7 +178,7 @@ async def get_message(
|
||||
|
||||
try:
|
||||
messages = await redis_message_system.get_messages(channel_id, limit, since)
|
||||
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
|
||||
if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]:
|
||||
messages.reverse()
|
||||
return messages
|
||||
except Exception as e:
|
||||
@@ -189,7 +190,7 @@ async def get_message(
|
||||
# 向前加载新消息 → 直接 ASC
|
||||
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
rows = (await session.exec(query)).all()
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||
# 已经 ASC,无需反转
|
||||
return resp
|
||||
|
||||
@@ -202,15 +203,14 @@ async def get_message(
|
||||
rows = (await session.exec(query)).all()
|
||||
rows = list(rows)
|
||||
rows.reverse() # 反转为 ASC
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||
return resp
|
||||
|
||||
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
rows = (await session.exec(query)).all()
|
||||
rows = list(rows)
|
||||
rows.reverse() # 反转为 ASC
|
||||
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
|
||||
return resp
|
||||
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
|
||||
return resp
|
||||
|
||||
|
||||
@@ -248,17 +248,23 @@ class PMReq(BaseModel):
|
||||
uuid: str | None = None
|
||||
|
||||
|
||||
class NewPMResp(BaseModel):
|
||||
channel: ChatChannelResp
|
||||
message: ChatMessageResp
|
||||
new_channel_id: int
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/new",
|
||||
name="创建私聊频道",
|
||||
description="创建一个新的私聊频道。",
|
||||
tags=["聊天"],
|
||||
responses={
|
||||
200: api_doc(
|
||||
"创建私聊频道响应",
|
||||
{
|
||||
"channel": ChatChannelModel,
|
||||
"message": ChatMessageModel,
|
||||
"new_channel_id": int,
|
||||
},
|
||||
["recent_messages.sender", "sender"],
|
||||
name="NewPMResponse",
|
||||
)
|
||||
},
|
||||
)
|
||||
async def create_new_pm(
|
||||
session: Database,
|
||||
@@ -290,9 +296,9 @@ async def create_new_pm(
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
await server.batch_join_channel([target, current_user], channel)
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=current_user, server=server, includes=["recent_messages.sender"]
|
||||
)
|
||||
msg = ChatMessage(
|
||||
channel_id=channel.channel_id,
|
||||
@@ -306,10 +312,10 @@ async def create_new_pm(
|
||||
await session.refresh(msg)
|
||||
await session.refresh(current_user)
|
||||
await session.refresh(channel)
|
||||
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
|
||||
message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"])
|
||||
await server.send_message_to_channel(message_resp)
|
||||
return NewPMResp(
|
||||
channel=channel_resp,
|
||||
message=message_resp,
|
||||
new_channel_id=channel_resp.channel_id,
|
||||
)
|
||||
return {
|
||||
"channel": channel_resp,
|
||||
"message": message_resp,
|
||||
"new_channel_id": channel_resp["channel_id"],
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Annotated, overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database import ChatMessageDict
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import (
|
||||
@@ -16,7 +17,7 @@ from app.log import log
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
from app.utils import bg_tasks, safe_json_dumps
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -65,7 +66,7 @@ class ChatServer:
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
await self.leave_channel(user, db_channel)
|
||||
|
||||
@overload
|
||||
async def send_event(self, client: int, event: ChatEvent): ...
|
||||
@@ -80,7 +81,7 @@ class ChatServer:
|
||||
return
|
||||
client = client_
|
||||
if client.client_state == WebSocketState.CONNECTED:
|
||||
await client.send_text(event.model_dump_json())
|
||||
await client.send_text(safe_json_dumps(event))
|
||||
|
||||
async def broadcast(self, channel_id: int, event: ChatEvent):
|
||||
users_in_channel = self.channels.get(channel_id, [])
|
||||
@@ -107,38 +108,38 @@ class ChatServer:
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
||||
async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: "
|
||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
||||
f"Sending message to channel {message['channel_id']}, message_id: "
|
||||
f"{message['message_id']}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
event="chat.message.new",
|
||||
data={"messages": [message], "users": [message.sender]},
|
||||
data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
||||
logger.info(f"Sending bot command to user {message['sender_id']}")
|
||||
bg_tasks.add_task(self.send_event, message["sender_id"], event)
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
|
||||
bg_tasks.add_task(
|
||||
self.broadcast,
|
||||
message.channel_id,
|
||||
message["channel_id"],
|
||||
event,
|
||||
)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
if message["message_id"] and message["message_id"] > 0:
|
||||
await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
|
||||
await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
|
||||
logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
|
||||
else:
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
|
||||
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel):
|
||||
channel_id = channel.channel_id
|
||||
|
||||
not_joined = []
|
||||
@@ -151,22 +152,18 @@ class ChatServer:
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
||||
async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
|
||||
@@ -175,25 +172,21 @@ class ChatServer:
|
||||
if user_id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user_id)
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp: ChatChannelDict = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.join",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
||||
async def leave_channel(self, user: User, channel: ChatChannel) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
|
||||
@@ -203,18 +196,14 @@ class ChatServer:
|
||||
if (c := self.channels.get(channel_id)) is not None and not c:
|
||||
del self.channels[channel_id]
|
||||
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
||||
channel_resp = await ChatChannelModel.transform(
|
||||
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
ChatEvent(
|
||||
event="chat.channel.part",
|
||||
data=channel_resp.model_dump(),
|
||||
data=channel_resp, # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
)
|
||||
|
||||
@@ -232,7 +221,7 @@ class ChatServer:
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
|
||||
await self.join_channel(user, db_channel, session)
|
||||
await self.join_channel(user, db_channel)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
@@ -248,7 +237,7 @@ class ChatServer:
|
||||
return
|
||||
|
||||
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
await self.leave_channel(user, db_channel)
|
||||
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
@@ -336,6 +325,6 @@ async def chat_websocket(
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
await server.join_channel(user, db_channel)
|
||||
|
||||
await _listen_stop(websocket, user_id, factory)
|
||||
|
||||
Reference in New Issue
Block a user