refactor(app): update database code
This commit is contained in:
@@ -8,7 +8,7 @@ import uuid
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import DBFactory, get_db_factory
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
|
||||
from .hub import Hubs
|
||||
@@ -16,7 +16,6 @@ from .packet import PROTOCOLS, SEP
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
|
||||
from fastapi.security import SecurityScopes
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(prefix="/signalr", include_in_schema=False)
|
||||
|
||||
@@ -47,7 +46,7 @@ async def connect(
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
@@ -56,13 +55,14 @@ async def connect(
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
try:
|
||||
if (
|
||||
user := await get_current_user(
|
||||
SecurityScopes(scopes=["*"]), db, token_pw=token
|
||||
)
|
||||
) is None or str(user.id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
async for session in factory():
|
||||
if (
|
||||
user := await get_current_user(
|
||||
session, SecurityScopes(scopes=["*"]), token_pw=token
|
||||
)
|
||||
) is None or str(user.id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
except HTTPException:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user