Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
298 lines
11 KiB
Python
298 lines
11 KiB
Python
from typing import Annotated, Literal, Self
|
|
|
|
from app.database.chat import (
|
|
ChannelType,
|
|
ChatChannel,
|
|
ChatChannelModel,
|
|
ChatMessage,
|
|
SilenceUser,
|
|
UserSilenceResp,
|
|
)
|
|
from app.database.user import User, UserModel
|
|
from app.dependencies.database import Database, Redis
|
|
from app.dependencies.param import BodyOrForm
|
|
from app.dependencies.user import get_current_user
|
|
from app.router.v2 import api_v2_router as router
|
|
from app.utils import api_doc
|
|
|
|
from .server import server
|
|
|
|
from fastapi import Depends, HTTPException, Path, Query, Security
|
|
from pydantic import BaseModel, model_validator
|
|
from sqlmodel import col, select
|
|
|
|
|
|
@router.get(
    "/chat/updates",
    name="获取更新",
    description="获取当前用户所在频道的最新的禁言情况。",
    tags=["聊天"],
    responses={
        200: api_doc(
            "获取更新响应。",
            {"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
            ChatChannel.LISTING_INCLUDES,
            name="UpdateResponse",
        )
    },
)
async def get_update(
    session: Database,
    current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
    redis: Redis,
    history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
    since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
    includes: Annotated[
        list[str],
        Query(alias="includes[]", description="要包含的更新类型"),
    ] = ["presence", "silences"],
):
    """Return the caller's joined channels and/or silence records.

    The response always carries both the ``presence`` and ``silences`` keys;
    a section stays empty unless its name is listed in ``includes``.

    Silences are selected by ``history_since`` (silence id) when it is truthy,
    otherwise by the timestamp of the message whose id is ``since``. Both
    checks are truthiness checks, so a value of 0 is treated like an absent
    parameter. ``redis`` is injected but not read here.
    """
    presence = []
    silence_entries = []

    if "presence" in includes:
        for joined_id in server.get_user_joined_channel(current_user.id):
            # Use an explicit query to avoid lazy loading.
            lookup = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == joined_id))
            row = lookup.first()
            if not row:
                # The server reports a channel that has no database row; skip it.
                continue
            presence.append(
                await ChatChannelModel.transform(
                    row,
                    user=current_user,
                    server=server,
                    includes=ChatChannel.LISTING_INCLUDES,
                )
            )

    if "silences" in includes:
        # Pick the filter first, then run at most one silence query.
        silence_query = None
        if history_since:
            silence_query = select(SilenceUser).where(col(SilenceUser.id) > history_since)
        elif since:
            anchor = await session.get(ChatMessage, since)
            if anchor:
                silence_query = select(SilenceUser).where(col(SilenceUser.banned_at) > anchor.timestamp)

        if silence_query is not None:
            rows = (await session.exec(silence_query)).all()
            silence_entries.extend([UserSilenceResp.from_db(entry) for entry in rows])

    return {
        "presence": presence,
        "silences": silence_entries,
    }
|
|
|
|
|
|
@router.put(
    "/chat/channels/{channel}/users/{user}",
    name="加入频道",
    description="加入指定的公开/房间频道。",
    tags=["聊天"],
    responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
)
async def join_channel(
    session: Database,
    channel: Annotated[str, Path(..., description="频道 ID/名称")],
    user: Annotated[str, Path(..., description="用户 ID")],
    current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
):
    """Join the authenticated user to a channel given by numeric id or by name.

    The ``user`` path segment is accepted for route compatibility but is not
    read; the user who joins is always ``current_user``.

    Returns:
        Whatever ``server.join_channel`` returns for the resolved channel.

    Raises:
        HTTPException: 403 if the current user is restricted, 404 if no
            channel matches ``channel``.
    """
    if await current_user.is_restricted(session):
        raise HTTPException(status_code=403, detail="You are restricted from sending messages")

    # Use an explicit query to avoid lazy loading.
    # str.isdecimal() (not isdigit()) guards the int() call: isdigit() is also
    # true for characters such as "²" that int() rejects with ValueError, which
    # would surface as an unhandled server error instead of a 404.
    if channel.isdecimal():
        statement = select(ChatChannel).where(ChatChannel.channel_id == int(channel))
    else:
        statement = select(ChatChannel).where(ChatChannel.name == channel)
    db_channel = (await session.exec(statement)).first()

    if db_channel is None:
        raise HTTPException(status_code=404, detail="Channel not found")
    return await server.join_channel(current_user, db_channel)
|
|
|
|
|
|
@router.delete(
    "/chat/channels/{channel}/users/{user}",
    status_code=204,
    name="离开频道",
    description="将用户移出指定的公开/房间频道。",
    tags=["聊天"],
)
async def leave_channel(
    session: Database,
    channel: Annotated[str, Path(..., description="频道 ID/名称")],
    user: Annotated[str, Path(..., description="用户 ID")],
    current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
):
    """Remove the authenticated user from a channel given by numeric id or by name.

    The ``user`` path segment is accepted for route compatibility but is not
    read; the user who leaves is always ``current_user``. Responds with 204
    and no body on success.

    Raises:
        HTTPException: 403 if the current user is restricted, 404 if no
            channel matches ``channel``.
    """
    if await current_user.is_restricted(session):
        raise HTTPException(status_code=403, detail="You are restricted from sending messages")

    # Use an explicit query to avoid lazy loading.
    # str.isdecimal() (not isdigit()) guards the int() call: isdigit() is also
    # true for characters such as "²" that int() rejects with ValueError, which
    # would surface as an unhandled server error instead of a 404.
    if channel.isdecimal():
        statement = select(ChatChannel).where(ChatChannel.channel_id == int(channel))
    else:
        statement = select(ChatChannel).where(ChatChannel.name == channel)
    db_channel = (await session.exec(statement)).first()

    if db_channel is None:
        raise HTTPException(status_code=404, detail="Channel not found")
    await server.leave_channel(current_user, db_channel)
    return
|
|
|
|
|
|
@router.get(
    "/chat/channels",
    responses={200: api_doc("加入的频道", list[ChatChannelModel])},
    name="获取频道列表",
    description="获取所有公开频道。",
    tags=["聊天"],
)
async def get_channel_list(
    session: Database,
    current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
):
    """List every channel whose type is ``ChannelType.PUBLIC``.

    The rows are passed through ``ChatChannelModel.transform_many`` with the
    current user and the chat server as context, and that result is returned.
    """
    public_only = select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
    public_channels = (await session.exec(public_only)).all()
    return await ChatChannelModel.transform_many(
        public_channels,
        user=current_user,
        server=server,
    )
|
|
|
|
|
|
@router.get(
    "/chat/channels/{channel}",
    responses={
        200: api_doc(
            "频道详细信息",
            {
                "channel": ChatChannelModel,
                "users": list[UserModel],
            },
            ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
            name="GetChannelResponse",
        )
    },
    name="获取频道信息",
    description="获取指定频道的信息。",
    tags=["聊天"],
)
async def get_channel(
    session: Database,
    channel: Annotated[str, Path(..., description="频道 ID/名称")],
    current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
    redis: Redis,
):
    """Return one channel plus, for PM channels, the two participants.

    ``channel`` is treated as a numeric channel id when it consists solely of
    decimal digits, otherwise as a channel name. For PM channels the channel
    name is split on ``"_"`` and the segments after the first are read as user
    ids; the first id that is not the caller's is loaded as the other
    participant. For every other channel type ``users`` is empty. ``redis`` is
    injected but not read here.

    Returns:
        A dict with ``channel`` (transformed with ``LISTING_INCLUDES``) and
        ``users`` (transformed with ``CARD_INCLUDES``).

    Raises:
        HTTPException: 404 if the channel does not exist, if a PM channel name
            does not yield exactly two ids, or if the other participant is
            missing or restricted.
    """
    # Use an explicit query to avoid lazy loading.
    # str.isdecimal() (not isdigit()) guards the int() call: isdigit() is also
    # true for characters such as "²" that int() rejects with ValueError, which
    # would surface as an unhandled server error instead of a 404.
    if channel.isdecimal():
        statement = select(ChatChannel).where(ChatChannel.channel_id == int(channel))
    else:
        statement = select(ChatChannel).where(ChatChannel.name == channel)
    db_channel = (await session.exec(statement)).first()

    if db_channel is None:
        raise HTTPException(status_code=404, detail="Channel not found")

    # Extract the needed attributes right away.
    channel_type = db_channel.type
    channel_name = db_channel.name

    users = []
    if channel_type == ChannelType.PM:
        user_ids = channel_name.split("_")[1:]
        if len(user_ids) != 2:
            raise HTTPException(status_code=404, detail="Target user not found")
        for id_ in user_ids:
            if int(id_) == current_user.id:
                continue
            target_user = await session.get(User, int(id_))
            if target_user is None or await target_user.is_restricted(session):
                raise HTTPException(status_code=404, detail="Target user not found")
            users.extend([target_user, current_user])
            # Only the first id that is not the caller's is considered.
            break

    return {
        "channel": await ChatChannelModel.transform(
            db_channel,
            user=current_user,
            server=server,
            includes=ChatChannel.LISTING_INCLUDES,
        ),
        "users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
    }
|
|
|
|
|
|
# Request payload for creating a chat channel. ``type`` selects between a
# private-message channel ("PM", the default) and an announcement channel
# ("ANNOUNCE"); the validator below enforces which fields each type needs.
class CreateChannelReq(BaseModel):
    # Name and description supplied for an announcement channel.
    class AnnounceChannel(BaseModel):
        name: str
        description: str

    message: str | None = None
    type: Literal["ANNOUNCE", "PM"] = "PM"
    target_id: int | None = None
    target_ids: list[int] | None = None
    channel: AnnounceChannel | None = None

    @model_validator(mode="after")
    def check(self) -> Self:
        """Reject payloads that lack the fields their channel type requires.

        PM requests need ``target_id``; ANNOUNCE requests need ``target_ids``,
        ``channel`` and ``message``.
        """
        if self.type == "PM":
            if self.target_id is None:
                raise ValueError("target_id must be set for PM channels")
            return self

        announce_fields = (self.target_ids, self.channel, self.message)
        if any(value is None for value in announce_fields):
            raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
        return self
|
|
|
|
|
|
@router.post(
    "/chat/channels",
    responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
    name="创建频道",
    description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
    tags=["聊天"],
)
async def create_channel(
    session: Database,
    req: Annotated[CreateChannelReq, Depends(BodyOrForm(CreateChannelReq))],
    current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
    redis: Redis,
):
    """Create a PM or announcement channel, or reuse an existing one, and join users to it.

    For ``type == "PM"`` the target user is loaded and checked, and an existing
    channel is looked up through ``ChatChannel.get_pm_channel``. Otherwise an
    existing channel is looked up by the requested name. If neither lookup
    finds a channel, a new ``ChatChannel`` row is committed.

    ``req.message`` and ``redis`` are not read in this function.

    Returns:
        The channel transformed with the ``recent_messages.sender`` include.

    Raises:
        HTTPException: 403 if the current user is restricted; 404 if the PM
            target does not exist or is restricted; 403 (with the reason
            returned by ``is_user_can_pm`` as detail) if the target cannot be
            messaged by the current user.
    """
    if await current_user.is_restricted(session):
        raise HTTPException(status_code=403, detail="You are restricted from sending messages")

    if req.type == "PM":
        target = await session.get(User, req.target_id)
        if not target or await target.is_restricted(session):
            raise HTTPException(status_code=404, detail="Target user not found")
        is_can_pm, block = await target.is_user_can_pm(current_user, session)
        if not is_can_pm:
            raise HTTPException(status_code=403, detail=block)

        channel = await ChatChannel.get_pm_channel(
            current_user.id,
            req.target_id,  # pyright: ignore[reportArgumentType]
            session,
        )
        # Name used only if a new PM channel row has to be created below.
        channel_name = f"pm_{current_user.id}_{req.target_id}"
    else:
        # The request validator requires `channel` for ANNOUNCE, so the
        # fallback name is only reached if validation was bypassed.
        channel_name = req.channel.name if req.channel else "Unnamed Channel"
        result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
        channel = result.first()

    if channel is None:
        channel = ChatChannel(
            name=channel_name,
            description=req.channel.description if req.channel else "Private message channel",
            type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
        )
        session.add(channel)
        await session.commit()
        # Reload the objects after the commit before they are used again
        # (presumably because commit expires loaded attributes — confirm).
        await session.refresh(channel)
        await session.refresh(current_user)
    # NOTE(review): the copy of this file under review had lost its indentation.
    # This join step is assumed to run for both new and existing channels;
    # confirm the nesting against version control.
    if req.type == "PM":
        await session.refresh(target)  # pyright: ignore[reportPossiblyUnboundVariable]
        await server.batch_join_channel([target, current_user], channel)  # pyright: ignore[reportPossiblyUnboundVariable]
    else:
        target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
        await server.batch_join_channel([*target_users, current_user], channel)

    await server.join_channel(current_user, channel)

    return await ChatChannelModel.transform(
        channel, user=current_user, server=server, includes=["recent_messages.sender"]
    )
|