Update server.py

This commit is contained in:
咕谷酱
2025-08-28 04:21:43 +08:00
parent 1f53c66700
commit 7a0283086d

View File

@@ -19,7 +19,7 @@ from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes from fastapi.security import SecurityScopes
from fastapi.websockets import WebSocketState from fastapi.websockets import WebSocketState
from redis.asyncio import Redis from redis.asyncio import Redis
@@ -289,7 +289,9 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
@chat_router.websocket("/notification-server") @chat_router.websocket("/notification-server")
async def chat_websocket( async def chat_websocket(
websocket: WebSocket, websocket: WebSocket,
authorization: str = Header(...), token: str | None = Query(None, description="认证令牌支持通过URL参数传递"),
access_token: str | None = Query(None, description="访问令牌支持通过URL参数传递"),
authorization: str | None = Header(None, description="Bearer认证头"),
factory: DBFactory = Depends(get_db_factory), factory: DBFactory = Depends(get_db_factory),
): ):
if not server._subscribed: if not server._subscribed:
@@ -297,9 +299,20 @@ async def chat_websocket(
await server.ChatSubscriber.start_subscribe() await server.ChatSubscriber.start_subscribe()
async for session in factory(): async for session in factory():
token = authorization[7:] # 优先使用查询参数中的token支持token或access_token参数名
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None: auth_token = token or access_token
await websocket.close(code=1008) if not auth_token and authorization:
if authorization.startswith("Bearer "):
auth_token = authorization[7:]
else:
auth_token = authorization
if not auth_token:
await websocket.close(code=1008, reason="Missing authentication token")
return
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None:
await websocket.close(code=1008, reason="Invalid or expired token")
return return
await websocket.accept() await websocket.accept()