fix(session): fix wrong usages of with_db

This commit is contained in:
MingxuanGame
2025-10-02 15:26:54 +00:00
parent 9a77c8d246
commit 2e1d922f59
2 changed files with 21 additions and 37 deletions

View File

@@ -111,14 +111,11 @@ class SessionVerificationState:
async def mark_verified(self) -> None: async def mark_verified(self) -> None:
"""标记会话为已验证""" """标记会话为已验证"""
try: try:
# 创建专用数据库会话 async with with_db() as db:
db = with_db() # 创建专用数据库会话
try:
token_id = self.session.token_id token_id = self.session.token_id
if token_id is not None: if token_id is not None:
await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id) await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id)
finally:
await db.close()
except Exception as e: except Exception as e:
logger.error(f"[Session Verification] Error marking session verified: {e}") logger.error(f"[Session Verification] Error marking session verified: {e}")
@@ -141,13 +138,10 @@ class SessionVerificationState:
from app.service.verification_service import EmailVerificationService from app.service.verification_service import EmailVerificationService
# 创建专用数据库会话发送邮件 # 创建专用数据库会话发送邮件
db = with_db() async with with_db() as db:
try:
await EmailVerificationService.send_verification_email( await EmailVerificationService.send_verification_email(
db, self.redis, self.user.id, self.user.username, self.user.email, None, None db, self.redis, self.user.id, self.user.username, self.user.email, None, None
) )
finally:
await db.close()
except Exception as e: except Exception as e:
logger.error(f"[Session Verification] Error issuing mail: {e}") logger.error(f"[Session Verification] Error issuing mail: {e}")
@@ -229,26 +223,26 @@ class SessionVerificationMiddleware:
return await call_next(request) return await call_next(request)
# 获取数据库和Redis连接 # 获取数据库和Redis连接
db = await self._get_db() async with with_db() as db:
redis = await self._get_redis() redis = await self._get_redis()
# 获取会话验证状态 # 获取会话验证状态
state = await SessionVerificationState.get_current(request, db, redis, user) state = await SessionVerificationState.get_current(request, db, redis, user)
if not state: if not state:
# 无法获取会话状态,继续请求 # 无法获取会话状态,继续请求
return await call_next(request) return await call_next(request)
# 检查是否已验证 # 检查是否已验证
if state.is_verified(): if state.is_verified():
# 已验证,继续请求 # 已验证,继续请求
return await call_next(request) return await call_next(request)
# 检查是否需要验证 # 检查是否需要验证
if not self._requires_verification(request): if not self._requires_verification(request):
return await call_next(request) return await call_next(request)
# 启动验证流程 # 启动验证流程
return await SessionVerificationController.initiate_verification(state, request) return await SessionVerificationController.initiate_verification(state, request)
except Exception as e: except Exception as e:
logger.error(f"[Session Verification Middleware] Unexpected error: {e}") logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
@@ -264,10 +258,6 @@ class SessionVerificationMiddleware:
except Exception: except Exception:
return None return None
async def _get_db(self) -> AsyncSession:
"""获取数据库连接"""
return with_db()
async def _get_redis(self) -> Redis: async def _get_redis(self) -> Redis:
"""获取Redis连接""" """获取Redis连接"""
return get_redis() return get_redis()

View File

@@ -135,8 +135,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
token = auth_header[7:] # 移除"Bearer "前缀 token = auth_header[7:] # 移除"Bearer "前缀
# 创建专用数据库会话 # 创建专用数据库会话
db = with_db() async with with_db() as db:
try:
# 获取token记录 # 获取token记录
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
@@ -145,8 +144,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
# 获取用户 # 获取用户
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user return user
finally:
await db.close()
except Exception as e: except Exception as e:
logger.debug(f"[Verify Session Middleware] Error getting user: {e}") logger.debug(f"[Verify Session Middleware] Error getting user: {e}")
@@ -163,8 +160,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
session_token = auth_header[7:] session_token = auth_header[7:]
# 获取数据库和Redis连接 # 获取数据库和Redis连接
db = with_db() async with with_db() as db:
try:
redis = get_redis() redis = get_redis()
# 查找会话 # 查找会话
@@ -173,8 +169,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return None return None
return SessionState(session, user, redis, db) return SessionState(session, user, redis, db)
finally:
await db.close()
except Exception as e: except Exception as e:
logger.error(f"[Verify Session Middleware] Error getting session state: {e}") logger.error(f"[Verify Session Middleware] Error getting session state: {e}")