Update server.py

This commit is contained in:
咕谷酱
2025-08-28 04:21:43 +08:00
parent 1f53c66700
commit 7a0283086d

View File

@@ -19,7 +19,7 @@ from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState
from redis.asyncio import Redis
@@ -289,7 +289,9 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
@chat_router.websocket("/notification-server")
async def chat_websocket(
websocket: WebSocket,
authorization: str = Header(...),
token: str | None = Query(None, description="认证令牌支持通过URL参数传递"),
access_token: str | None = Query(None, description="访问令牌支持通过URL参数传递"),
authorization: str | None = Header(None, description="Bearer认证头"),
factory: DBFactory = Depends(get_db_factory),
):
if not server._subscribed:
@@ -297,9 +299,20 @@ async def chat_websocket(
await server.ChatSubscriber.start_subscribe()
async for session in factory():
token = authorization[7:]
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
await websocket.close(code=1008)
# 优先使用查询参数中的token支持token或access_token参数名
auth_token = token or access_token
if not auth_token and authorization:
if authorization.startswith("Bearer "):
auth_token = authorization[7:]
else:
auth_token = authorization
if not auth_token:
await websocket.close(code=1008, reason="Missing authentication token")
return
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None:
await websocket.close(code=1008, reason="Invalid or expired token")
return
await websocket.accept()