refactor(api): use Annotated-style dependency injection

This commit is contained in:
MingxuanGame
2025-10-03 05:41:31 +00:00
parent 37b4eadf79
commit 346c2557cf
45 changed files with 623 additions and 577 deletions

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Annotated
from app.database import ChatMessageResp
from app.database.chat import (
ChannelType,
@@ -11,7 +13,7 @@ from app.database.chat import (
UserSilenceResp,
)
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.log import logger
@@ -24,7 +26,6 @@ from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -41,9 +42,9 @@ class KeepAliveResp(BaseModel):
)
async def keep_alive(
session: Database,
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
):
resp = KeepAliveResp()
if history_since:
@@ -73,9 +74,9 @@ class MessageReq(BaseModel):
)
async def send_message(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
req: MessageReq = Depends(BodyOrForm(MessageReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
):
# 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit():
@@ -156,10 +157,10 @@ async def send_message(
async def get_message(
session: Database,
channel: str,
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
since: int = Query(0, ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)"),
until: int | None = Query(None, description="获取自此消息 ID 之的消息(向后翻历史"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
limit: Annotated[int, Query(ge=1, le=50, description="获取消息的数量")] = 50,
since: Annotated[int, Query(ge=0, description="获取自此消息 ID 之的消息(向前加载新消息")] = 0,
until: Annotated[int | None, Query(description="获取自此消息 ID 之前的消息(向后翻历史)")] = None,
):
# 1) 查频道
if channel.isdigit():
@@ -220,9 +221,9 @@ async def get_message(
)
async def mark_as_read(
session: Database,
channel: str = Path(..., description="频道 ID/名称"),
message: int = Path(..., description="消息 ID"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
channel: Annotated[str, Path(..., description="频道 ID/名称")],
message: Annotated[int, Path(..., description="消息 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
):
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
@@ -259,9 +260,9 @@ class NewPMResp(BaseModel):
)
async def create_new_pm(
session: Database,
req: PMReq = Depends(BodyOrForm(PMReq)),
current_user: User = Security(get_current_user, scopes=["chat.write"]),
redis: Redis = Depends(get_redis),
req: Annotated[PMReq, Depends(BodyOrForm(PMReq))],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
redis: Redis,
):
user_id = current_user.id
target = await session.get(User, req.target_id)