diff --git a/app/middleware/session_verification.py b/app/middleware/session_verification.py index f2db36e..eceb931 100644 --- a/app/middleware/session_verification.py +++ b/app/middleware/session_verification.py @@ -111,14 +111,11 @@ class SessionVerificationState: async def mark_verified(self) -> None: """标记会话为已验证""" try: - # 创建专用数据库会话 - db = with_db() - try: + async with with_db() as db: + # 创建专用数据库会话 token_id = self.session.token_id if token_id is not None: await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id) - finally: - await db.close() except Exception as e: logger.error(f"[Session Verification] Error marking session verified: {e}") @@ -141,13 +138,10 @@ class SessionVerificationState: from app.service.verification_service import EmailVerificationService # 创建专用数据库会话发送邮件 - db = with_db() - try: + async with with_db() as db: await EmailVerificationService.send_verification_email( db, self.redis, self.user.id, self.user.username, self.user.email, None, None ) - finally: - await db.close() except Exception as e: logger.error(f"[Session Verification] Error issuing mail: {e}") @@ -229,26 +223,26 @@ class SessionVerificationMiddleware: return await call_next(request) # 获取数据库和Redis连接 - db = await self._get_db() - redis = await self._get_redis() + async with with_db() as db: + redis = await self._get_redis() - # 获取会话验证状态 - state = await SessionVerificationState.get_current(request, db, redis, user) - if not state: - # 无法获取会话状态,继续请求 - return await call_next(request) + # 获取会话验证状态 + state = await SessionVerificationState.get_current(request, db, redis, user) + if not state: + # 无法获取会话状态,继续请求 + return await call_next(request) - # 检查是否已验证 - if state.is_verified(): - # 已验证,继续请求 - return await call_next(request) + # 检查是否已验证 + if state.is_verified(): + # 已验证,继续请求 + return await call_next(request) - # 检查是否需要验证 - if not self._requires_verification(request): - return await call_next(request) + # 检查是否需要验证 + if not self._requires_verification(request): + return await call_next(request) - # 启动验证流程 - return await SessionVerificationController.initiate_verification(state, request) + # 启动验证流程 + return await SessionVerificationController.initiate_verification(state, request) except Exception as e: logger.error(f"[Session Verification Middleware] Unexpected error: {e}") @@ -264,10 +258,6 @@ class SessionVerificationMiddleware: except Exception: return None - async def _get_db(self) -> AsyncSession: - """获取数据库连接""" - return with_db() - async def _get_redis(self) -> Redis: """获取Redis连接""" return get_redis() diff --git a/app/middleware/verify_session.py b/app/middleware/verify_session.py index a935566..fbc7329 100644 --- a/app/middleware/verify_session.py +++ b/app/middleware/verify_session.py @@ -135,8 +135,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): token = auth_header[7:] # 移除"Bearer "前缀 # 创建专用数据库会话 - db = with_db() - try: + async with with_db() as db: # 获取token记录 token_record = await get_token_by_access_token(db, token) if not token_record: @@ -145,8 +144,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): # 获取用户 user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() return user - finally: - await db.close() except Exception as e: logger.debug(f"[Verify Session Middleware] Error getting user: {e}") @@ -163,8 +160,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): session_token = auth_header[7:] # 获取数据库和Redis连接 - db = with_db() - try: + async with with_db() as db: redis = get_redis() # 查找会话 @@ -173,8 +169,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): return None return SessionState(session, user, redis, db) - finally: - await db.close() except Exception as e: logger.error(f"[Verify Session Middleware] Error getting session state: {e}")