Update server.py
This commit is contained in:
@@ -19,7 +19,7 @@ from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
from fastapi.websockets import WebSocketState
|
||||
from redis.asyncio import Redis
|
||||
@@ -289,7 +289,9 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
@chat_router.websocket("/notification-server")
|
||||
async def chat_websocket(
|
||||
websocket: WebSocket,
|
||||
authorization: str = Header(...),
|
||||
token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"),
|
||||
access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"),
|
||||
authorization: str | None = Header(None, description="Bearer认证头"),
|
||||
factory: DBFactory = Depends(get_db_factory),
|
||||
):
|
||||
if not server._subscribed:
|
||||
@@ -297,9 +299,20 @@ async def chat_websocket(
|
||||
await server.ChatSubscriber.start_subscribe()
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
|
||||
await websocket.close(code=1008)
|
||||
# 优先使用查询参数中的token,支持token或access_token参数名
|
||||
auth_token = token or access_token
|
||||
if not auth_token and authorization:
|
||||
if authorization.startswith("Bearer "):
|
||||
auth_token = authorization[7:]
|
||||
else:
|
||||
auth_token = authorization
|
||||
|
||||
if not auth_token:
|
||||
await websocket.close(code=1008, reason="Missing authentication token")
|
||||
return
|
||||
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None:
|
||||
await websocket.close(code=1008, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
Reference in New Issue
Block a user