refactor(api): use Annotated-style dependency injection
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Self
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
@@ -11,7 +11,7 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.user import User, UserResp
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
@@ -20,7 +20,6 @@ from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
@@ -38,11 +37,14 @@ class UpdateResponse(BaseModel):
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
|
||||
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
|
||||
includes: Annotated[
|
||||
list[str],
|
||||
Query(alias="includes[]", description="要包含的更新类型"),
|
||||
] = ["presence", "silences"],
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
@@ -86,9 +88,9 @@ async def get_update(
|
||||
)
|
||||
async def join_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
user: Annotated[str, Path(..., description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -110,9 +112,9 @@ async def join_channel(
|
||||
)
|
||||
async def leave_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
user: Annotated[str, Path(..., description="用户 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -135,8 +137,8 @@ async def leave_channel(
|
||||
)
|
||||
async def get_channel_list(
|
||||
session: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
):
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
@@ -171,9 +173,9 @@ class GetChannelResp(BaseModel):
|
||||
)
|
||||
async def get_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
redis: Redis,
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -245,9 +247,9 @@ class CreateChannelReq(BaseModel):
|
||||
)
|
||||
async def create_channel(
|
||||
session: Database,
|
||||
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
req: Annotated[CreateChannelReq, Depends(BodyOrForm(CreateChannelReq))],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
|
||||
redis: Redis,
|
||||
):
|
||||
if req.type == "PM":
|
||||
target = await session.get(User, req.target_id)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.database import ChatMessageResp
|
||||
from app.database.chat import (
|
||||
ChannelType,
|
||||
@@ -11,7 +13,7 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.database import Database, Redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
@@ -24,7 +26,6 @@ from .server import server
|
||||
|
||||
from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
@@ -41,9 +42,9 @@ class KeepAliveResp(BaseModel):
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
history_since: Annotated[int | None, Query(description="获取自此禁言 ID 之后的禁言记录")] = None,
|
||||
since: Annotated[int | None, Query(description="获取自此消息 ID 之后的禁言记录")] = None,
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
@@ -73,9 +74,9 @@ class MessageReq(BaseModel):
|
||||
)
|
||||
async def send_message(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
|
||||
):
|
||||
# 使用明确的查询来获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -156,10 +157,10 @@ async def send_message(
|
||||
async def get_message(
|
||||
session: Database,
|
||||
channel: str,
|
||||
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
||||
since: int = Query(0, ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)"),
|
||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息(向后翻历史)"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
limit: Annotated[int, Query(ge=1, le=50, description="获取消息的数量")] = 50,
|
||||
since: Annotated[int, Query(ge=0, description="获取自此消息 ID 之后的消息(向前加载新消息)")] = 0,
|
||||
until: Annotated[int | None, Query(description="获取自此消息 ID 之前的消息(向后翻历史)")] = None,
|
||||
):
|
||||
# 1) 查频道
|
||||
if channel.isdigit():
|
||||
@@ -220,9 +221,9 @@ async def get_message(
|
||||
)
|
||||
async def mark_as_read(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
message: int = Path(..., description="消息 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
channel: Annotated[str, Path(..., description="频道 ID/名称")],
|
||||
message: Annotated[int, Path(..., description="消息 ID")],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
@@ -259,9 +260,9 @@ class NewPMResp(BaseModel):
|
||||
)
|
||||
async def create_new_pm(
|
||||
session: Database,
|
||||
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
req: Annotated[PMReq, Depends(BodyOrForm(PMReq))],
|
||||
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
|
||||
redis: Redis,
|
||||
):
|
||||
user_id = current_user.id
|
||||
target = await session.get(User, req.target_id)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import overload
|
||||
from typing import Annotated, overload
|
||||
|
||||
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
|
||||
from app.database.notification import UserNotification, insert_notification
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
Redis,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
with_db,
|
||||
@@ -22,7 +23,6 @@ from app.utils import bg_tasks
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -298,10 +298,10 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
@chat_router.websocket("/notification-server")
|
||||
async def chat_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"),
|
||||
access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"),
|
||||
authorization: str | None = Header(None, description="Bearer认证头"),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
factory: Annotated[DBFactory, Depends(get_db_factory)],
|
||||
token: Annotated[str | None, Query(description="认证令牌,支持通过URL参数传递")] = None,
|
||||
access_token: Annotated[str | None, Query(description="访问令牌,支持通过URL参数传递")] = None,
|
||||
authorization: Annotated[str | None, Header(description="Bearer认证头")] = None,
|
||||
):
|
||||
if not server._subscribed:
|
||||
server._subscribed = True
|
||||
|
||||
Reference in New Issue
Block a user