""" 会话验证中间件和状态管理 基于osu-web的会话验证系统实现 """ from __future__ import annotations import asyncio from typing import Awaitable, Callable, Optional from fastapi import HTTPException, Request, Response, status from fastapi.responses import JSONResponse from redis.asyncio import Redis from sqlmodel.ext.asyncio.session import AsyncSession from app.database.lazer_user import User from app.database.verification import LoginSession from app.dependencies.database import with_db, get_redis from app.log import logger from app.service.verification_service import LoginSessionService class SessionVerificationState: """会话验证状态管理类 参考osu-web的State类实现 """ def __init__(self, session: LoginSession, user: User, redis: Redis): self.session = session self.user = user self.redis = redis @classmethod async def get_current( cls, request: Request, db: AsyncSession, redis: Redis, user: User, ) -> Optional[SessionVerificationState]: """获取当前会话验证状态""" try: # 从请求头或token中获取会话信息 session_token = cls._extract_session_token(request) if not session_token: return None # 查找会话 session = await LoginSessionService.find_for_verification(db, session_token) if not session or session.user_id != user.id: return None return cls(session, user, redis) except Exception as e: logger.error(f"[Session Verification] Error getting current state: {e}") return None @staticmethod def _extract_session_token(request: Request) -> Optional[str]: """从请求中提取会话token""" # 尝试从Authorization header提取 auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): return auth_header[7:] # 移除"Bearer "前缀 # 可以扩展其他提取方式 return None def get_method(self) -> str: """获取验证方法 参考osu-web的逻辑,智能选择验证方法 """ current_method = self.session.verification_method if current_method is None: # 智能选择验证方法 # 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证 # 这里简化为检查用户是否有TOTP totp_key = getattr(self.user, 'totp_key', None) current_method = 'totp' if totp_key else 'mail' # 设置验证方法 asyncio.create_task(self._set_verification_method(current_method)) return current_method async def _set_verification_method(self, method: str) -> None: """内部方法:设置验证方法""" try: token_id = self.session.token_id if token_id is not None and method in ['totp', 'mail']: # 类型检查确保method是正确的字面量类型 verification_method = method if method in ['totp', 'mail'] else 'totp' await LoginSessionService.set_login_method( self.user.id, token_id, verification_method, self.redis # type: ignore ) except Exception as e: logger.error(f"[Session Verification] Error setting verification method: {e}") def is_verified(self) -> bool: """检查会话是否已验证""" return self.session.is_verified async def mark_verified(self) -> None: """标记会话为已验证""" try: # 创建专用数据库会话 db = with_db() try: token_id = self.session.token_id if token_id is not None: await LoginSessionService.mark_session_verified( db, self.redis, self.user.id, token_id ) finally: await db.close() except Exception as e: logger.error(f"[Session Verification] Error marking session verified: {e}") def get_key(self) -> str: """获取会话密钥""" return str(self.session.id) if self.session.id else "" def get_key_for_event(self) -> str: """获取用于事件广播的会话密钥""" return LoginSessionService.get_key_for_event(self.get_key()) def user_id(self) -> int: """获取用户ID""" return self.user.id async def issue_mail_if_needed(self) -> None: """如果需要,发送验证邮件""" try: if self.get_method() == "mail": from app.service.verification_service import EmailVerificationService # 创建专用数据库会话发送邮件 db = with_db() try: await EmailVerificationService.send_verification_email( db, self.redis, self.user.id, self.user.username, self.user.email, None, None ) finally: await db.close() except Exception as e: logger.error(f"[Session Verification] Error issuing mail: {e}") class SessionVerificationController: """会话验证控制器 参考osu-web的Controller类实现 """ # 需要跳过验证的路由(参考osu-web的SKIP_VERIFICATION_ROUTES) SKIP_VERIFICATION_ROUTES = { "/api/v2/session/verify", "/api/v2/session/verify/reissue", "/api/v2/me", "/api/v2/logout", "/oauth/token", } @staticmethod def should_skip_verification(request: Request) -> bool: """检查是否应该跳过验证""" path = request.url.path return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES @staticmethod async def initiate_verification( state: SessionVerificationState, request: Request, ) -> Response: """启动会话验证流程 参考osu-web的initiate方法 """ try: method = state.get_method() # 如果是邮件验证,发送验证邮件 if method == "mail": await state.issue_mail_if_needed() # API请求返回JSON响应 if request.url.path.startswith("/api/"): return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method} ) # 其他情况可以扩展支持HTML响应 return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, content={ "authentication": "verify", "method": method, "message": "Session verification required" } ) except Exception as e: logger.error(f"[Session Verification] Error initiating verification: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed" ) class SessionVerificationMiddleware: """会话验证中间件 参考osu-web的VerifyUser中间件实现 """ def __init__(self, app: Callable[[Request], Awaitable[Response]]): self.app = app async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: """中间件主要逻辑""" try: # 检查是否需要跳过验证 if SessionVerificationController.should_skip_verification(request): return await call_next(request) # 获取依赖项 user = await self._get_user(request) if not user: # 未认证用户跳过验证 return await call_next(request) # 获取数据库和Redis连接 db = await self._get_db() redis = await self._get_redis() # 获取会话验证状态 state = await SessionVerificationState.get_current(request, db, redis, user) if not state: # 无法获取会话状态,继续请求 return await call_next(request) # 检查是否已验证 if state.is_verified(): # 已验证,继续请求 return await call_next(request) # 检查是否需要验证 if not self._requires_verification(request): return await call_next(request) # 启动验证流程 return await SessionVerificationController.initiate_verification(state, request) except Exception as e: logger.error(f"[Session Verification Middleware] Unexpected error: {e}") # 出错时允许请求继续,避免阻塞正常流程 return await call_next(request) async def _get_user(self, request: Request) -> Optional[User]: """获取当前用户""" try: # 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入 # 简化实现,实际应该从token中解析用户 return None # 暂时返回None,需要实际实现 except Exception: return None async def _get_db(self) -> AsyncSession: """获取数据库连接""" return with_db() async def _get_redis(self) -> Redis: """获取Redis连接""" return get_redis() def _requires_verification(self, request: Request) -> bool: """检查是否需要验证 参考osu-web的requiresVerification方法 """ method = request.method # GET/HEAD/OPTIONS请求一般不需要验证 safe_methods = {"GET", "HEAD", "OPTIONS"} if method in safe_methods: return False # POST/PUT/DELETE等修改操作需要验证 return True # FastAPI中间件包装器 class FastAPISessionVerificationMiddleware: """FastAPI会话验证中间件包装器""" def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): if scope["type"] != "http": return await self.app(scope, receive, send) request = Request(scope, receive) async def call_next(req: Request) -> Response: # 这里需要调用FastAPI应用 return Response("OK") # 占位符实现 middleware = SessionVerificationMiddleware(call_next) response = await middleware(request, call_next) await response(scope, receive, send)