refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -1,37 +1,40 @@
from typing import Annotated, Any, Literal, Self
from typing import Annotated, Literal, Self
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatChannelModel,
ChatMessage,
SilenceUser,
UserSilenceResp,
)
from app.database.user import User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
from app.utils import api_doc
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from sqlmodel import col, select
class UpdateResponse(BaseModel):
presence: list[ChatChannelResp] = Field(default_factory=list)
silences: list[Any] = Field(default_factory=list)
@router.get(
"/chat/updates",
response_model=UpdateResponse,
name="获取更新",
description="获取当前用户所在频道的最新的禁言情况。",
tags=["聊天"],
responses={
200: api_doc(
"获取更新响应。",
{"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
ChatChannel.LISTING_INCLUDES,
name="UpdateResponse",
)
},
)
async def get_update(
session: Database,
@@ -44,45 +47,44 @@ async def get_update(
Query(alias="includes[]", description="要包含的更新类型"),
] = ["presence", "silences"],
):
resp = UpdateResponse()
resp = {
"presence": [],
"silences": [],
}
if "presence" in includes:
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
resp.presence.append(
await ChatChannelResp.from_db(
resp["presence"].append(
await ChatChannelModel.transform(
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
user=current_user,
server=server,
includes=ChatChannel.LISTING_INCLUDES,
)
)
if "silences" in includes:
if history_since:
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@router.put(
"/chat/channels/{channel}/users/{user}",
response_model=ChatChannelResp,
name="加入频道",
description="加入指定的公开/房间频道。",
tags=["聊天"],
responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
)
async def join_channel(
session: Database,
@@ -101,7 +103,7 @@ async def join_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
return await server.join_channel(current_user, db_channel, session)
return await server.join_channel(current_user, db_channel)
@router.delete(
@@ -128,13 +130,13 @@ async def leave_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
await server.leave_channel(current_user, db_channel, session)
await server.leave_channel(current_user, db_channel)
return
@router.get(
"/chat/channels",
response_model=list[ChatChannelResp],
responses={200: api_doc("加入的频道", list[ChatChannelModel])},
name="获取频道列表",
description="获取所有公开频道。",
tags=["聊天"],
@@ -142,35 +144,30 @@ async def leave_channel(
async def get_channel_list(
session: Database,
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis,
):
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
for channel in channels:
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
results = await ChatChannelModel.transform_many(
channels,
user=current_user,
server=server,
)
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
return results
class GetChannelResp(BaseModel):
channel: ChatChannelResp
users: list[UserResp] = Field(default_factory=list)
@router.get(
"/chat/channels/{channel}",
response_model=GetChannelResp,
responses={
200: api_doc(
"频道详细信息",
{
"channel": ChatChannelModel,
"users": list[UserModel],
},
ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
name="GetChannelResponse",
)
},
name="获取频道信息",
description="获取指定频道的信息。",
tags=["聊天"],
@@ -191,7 +188,6 @@ async def get_channel(
raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
@@ -209,15 +205,15 @@ async def get_channel(
users.extend([target_user, current_user])
break
return GetChannelResp(
channel=await ChatChannelResp.from_db(
return {
"channel": await ChatChannelModel.transform(
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
user=current_user,
server=server,
includes=ChatChannel.LISTING_INCLUDES,
),
"users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
}
class CreateChannelReq(BaseModel):
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
@router.post(
"/chat/channels",
response_model=ChatChannelResp,
responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
name="创建频道",
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
tags=["聊天"],
@@ -289,21 +285,13 @@ async def create_channel(
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.batch_join_channel([*target_users, current_user], channel)
await server.join_channel(current_user, channel, session)
await server.join_channel(current_user, channel)
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, []),
include_recent_messages=True,
return await ChatChannelModel.transform(
channel, user=current_user, server=server, includes=["recent_messages.sender"]
)