refactor(app): update database code
This commit is contained in:
@@ -3,9 +3,11 @@ from __future__ import annotations
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextvars import ContextVar
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
import redis.asyncio as redis
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
@@ -52,7 +54,12 @@ async def get_db():
|
||||
yield session
|
||||
|
||||
|
||||
def with_db():
|
||||
return AsyncSession(engine)
|
||||
|
||||
|
||||
DBFactory = Callable[[], AsyncIterator[AsyncSession]]
|
||||
Database = Annotated[AsyncSession, Depends(get_db)]
|
||||
|
||||
|
||||
async def get_db_factory() -> DBFactory:
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.database import User
|
||||
from app.database.auth import V1APIKeys
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
|
||||
from .database import get_db
|
||||
from .database import Database
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import (
|
||||
@@ -19,7 +19,6 @@ from fastapi.security import (
|
||||
SecurityScopes,
|
||||
)
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
@@ -64,7 +63,7 @@ v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API
|
||||
|
||||
|
||||
async def v1_authorize(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Database,
|
||||
api_key: Annotated[str, Depends(v1_api_key)],
|
||||
):
|
||||
"""V1 API Key 授权"""
|
||||
@@ -79,8 +78,8 @@ async def v1_authorize(
|
||||
|
||||
|
||||
async def get_client_user(
|
||||
db: Database,
|
||||
token: Annotated[str, Depends(oauth2_password)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
@@ -95,8 +94,8 @@ async def get_client_user(
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: Database,
|
||||
security_scopes: SecurityScopes,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||
token_client_credentials: Annotated[
|
||||
|
||||
Reference in New Issue
Block a user