diff --git a/app/dependencies/database.py b/app/dependencies/database.py index c3ec9a4..1ced03c 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextvars import ContextVar import json from app.config import settings @@ -25,8 +26,22 @@ redis_client = redis.from_url(settings.redis_url, decode_responses=True) # 数据库依赖 +db_session_context: ContextVar[AsyncSession | None] = ContextVar( + "db_session_context", default=None +) + + async def get_db(): - async with AsyncSession(engine) as session: + session = db_session_context.get() + if session is None: + session = AsyncSession(engine) + db_session_context.set(session) + try: + yield session + finally: + await session.close() + db_session_context.set(None) + else: yield session