Update server.py
This commit is contained in:
@@ -19,7 +19,7 @@ from app.models.notification import NotificationDetail
|
|||||||
from app.service.subscribers.chat import ChatSubscriber
|
from app.service.subscribers.chat import ChatSubscriber
|
||||||
from app.utils import bg_tasks
|
from app.utils import bg_tasks
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.security import SecurityScopes
|
from fastapi.security import SecurityScopes
|
||||||
from fastapi.websockets import WebSocketState
|
from fastapi.websockets import WebSocketState
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
@@ -289,7 +289,9 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
|||||||
@chat_router.websocket("/notification-server")
|
@chat_router.websocket("/notification-server")
|
||||||
async def chat_websocket(
|
async def chat_websocket(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
authorization: str = Header(...),
|
token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"),
|
||||||
|
access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"),
|
||||||
|
authorization: str | None = Header(None, description="Bearer认证头"),
|
||||||
factory: DBFactory = Depends(get_db_factory),
|
factory: DBFactory = Depends(get_db_factory),
|
||||||
):
|
):
|
||||||
if not server._subscribed:
|
if not server._subscribed:
|
||||||
@@ -297,9 +299,20 @@ async def chat_websocket(
|
|||||||
await server.ChatSubscriber.start_subscribe()
|
await server.ChatSubscriber.start_subscribe()
|
||||||
|
|
||||||
async for session in factory():
|
async for session in factory():
|
||||||
token = authorization[7:]
|
# 优先使用查询参数中的token,支持token或access_token参数名
|
||||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
|
auth_token = token or access_token
|
||||||
await websocket.close(code=1008)
|
if not auth_token and authorization:
|
||||||
|
if authorization.startswith("Bearer "):
|
||||||
|
auth_token = authorization[7:]
|
||||||
|
else:
|
||||||
|
auth_token = authorization
|
||||||
|
|
||||||
|
if not auth_token:
|
||||||
|
await websocket.close(code=1008, reason="Missing authentication token")
|
||||||
|
return
|
||||||
|
|
||||||
|
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None:
|
||||||
|
await websocket.close(code=1008, reason="Invalid or expired token")
|
||||||
return
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|||||||
Reference in New Issue
Block a user