From 7a0283086d2c5cd5b50fd66f7d36d94bfad09f6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= <74496778+GooGuJiang@users.noreply.github.com> Date: Thu, 28 Aug 2025 04:21:43 +0800 Subject: [PATCH] Update server.py --- app/router/notification/server.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 086ce50..b5f24c2 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -19,7 +19,7 @@ from app.models.notification import NotificationDetail from app.service.subscribers.chat import ChatSubscriber from app.utils import bg_tasks -from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes from fastapi.websockets import WebSocketState from redis.asyncio import Redis @@ -289,7 +289,9 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): @chat_router.websocket("/notification-server") async def chat_websocket( websocket: WebSocket, - authorization: str = Header(...), + token: str | None = Query(None, description="认证令牌,支持通过URL参数传递"), + access_token: str | None = Query(None, description="访问令牌,支持通过URL参数传递"), + authorization: str | None = Header(None, description="Bearer认证头"), factory: DBFactory = Depends(get_db_factory), ): if not server._subscribed: @@ -297,9 +299,20 @@ async def chat_websocket( await server.ChatSubscriber.start_subscribe() async for session in factory(): - token = authorization[7:] - if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None: - await websocket.close(code=1008) + # 优先使用查询参数中的token,支持token或access_token参数名 + auth_token = token or access_token + if not auth_token and authorization: + if authorization.startswith("Bearer "): + auth_token = authorization[7:] + else: + auth_token = authorization + + if not auth_token: + await websocket.close(code=1008, reason="Missing authentication token") + return + + if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None: + await websocket.close(code=1008, reason="Invalid or expired token") return await websocket.accept()