diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 086ce50..b5f24c2 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -19,7 +19,7 @@ from app.models.notification import NotificationDetail from app.service.subscribers.chat import ChatSubscriber from app.utils import bg_tasks -from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes from fastapi.websockets import WebSocketState from redis.asyncio import Redis @@ -289,7 +289,9 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): @chat_router.websocket("/notification-server") async def chat_websocket( websocket: WebSocket, - authorization: str = Header(...), + token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"), + access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"), + authorization: str | None = Header(None, description="Bearer认证头"), factory: DBFactory = Depends(get_db_factory), ): if not server._subscribed: @@ -297,9 +299,20 @@ async def chat_websocket( await server.ChatSubscriber.start_subscribe() async for session in factory(): - token = authorization[7:] - if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None: - await websocket.close(code=1008) + # 优先使用查询参数中的token,支持token或access_token参数名 + auth_token = token or access_token + if not auth_token and authorization: + if authorization.startswith("Bearer "): + auth_token = authorization[7:] + else: + auth_token = authorization + + if not auth_token: + await websocket.close(code=1008, reason="Missing authentication token") + return + + if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None: + await websocket.close(code=1008, reason="Invalid or expired token") return await websocket.accept()