fix(database): fix cross-session user (current_user doesn't belong to get_db)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
@@ -25,8 +26,22 @@ redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
|||||||
|
|
||||||
|
|
||||||
# 数据库依赖
|
# 数据库依赖
|
||||||
|
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
|
||||||
|
"db_session_context", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_db():
|
async def get_db():
|
||||||
async with AsyncSession(engine) as session:
|
session = db_session_context.get()
|
||||||
|
if session is None:
|
||||||
|
session = AsyncSession(engine)
|
||||||
|
db_session_context.set(session)
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
db_session_context.set(None)
|
||||||
|
else:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user