Files
g0v0-server/app/middleware/session_verification.py
2025-09-30 07:57:08 +00:00

312 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
会话验证中间件和状态管理
基于osu-web的会话验证系统实现
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import ClassVar, Literal, cast
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from app.utils import bg_tasks
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
class SessionVerificationState:
    """Verification state of a single login session.

    Modelled on osu-web's ``State`` class (per the original author's note).
    Bundles the login-session record, the user it belongs to and the Redis
    client that the verification service calls need.
    """

    def __init__(self, session: LoginSession, user: User, redis: Redis):
        self.session = session
        self.user = user
        self.redis = redis
        # Method chosen by get_method() when the session has none stored yet.
        # Remembered so repeated calls on this object do not schedule the
        # persist task more than once.
        self._selected_method: str | None = None

    @classmethod
    async def get_current(
        cls,
        request: Request,
        db: AsyncSession,
        redis: Redis,
        user: User,
    ) -> SessionVerificationState | None:
        """Build the verification state for the session behind ``request``.

        Returns ``None`` when the request carries no bearer token, when no
        session is found for the token, when the session belongs to another
        user, or when the lookup raises (the error is logged, not propagated).
        """
        try:
            # The session is identified by the bearer token of the request.
            session_token = cls._extract_session_token(request)
            if not session_token:
                return None

            session = await LoginSessionService.find_for_verification(db, session_token)
            if not session or session.user_id != user.id:
                return None

            return cls(session, user, redis)
        except Exception as e:
            logger.error(f"[Session Verification] Error getting current state: {e}")
            return None

    @staticmethod
    def _extract_session_token(request: Request) -> str | None:
        """Return the bearer token from the ``Authorization`` header, if any.

        The auth scheme is matched case-insensitively (HTTP auth schemes are
        case-insensitive) and surrounding whitespace is ignored. A missing
        header, a different scheme or an empty token all yield ``None``.
        """
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        # Other extraction strategies (cookies, query string, ...) could be added here.
        return credentials.strip() or None

    def get_method(self) -> str:
        """Return the verification method for this session.

        If the session already stores a method it is returned as is.
        Otherwise the method is chosen once: ``"totp"`` when the user has a
        truthy ``totp_key`` attribute, ``"mail"`` otherwise, and a background
        task is scheduled to persist that choice.

        osu-web additionally falls back to mail for API versions < 20250913;
        that check is intentionally simplified away here.
        """
        stored_method = self.session.verification_method
        if stored_method is not None:
            return stored_method

        if self._selected_method is None:
            totp_key = getattr(self.user, "totp_key", None)
            self._selected_method = "totp" if totp_key else "mail"
            # Persist the choice in the background; only scheduled on first selection.
            bg_tasks.add_task(self._set_verification_method, self._selected_method)
        return self._selected_method

    async def _set_verification_method(self, method: str) -> None:
        """Persist ``method`` as the login method of this session.

        Does nothing when the session has no ``token_id`` or ``method`` is
        not one of ``"totp"`` / ``"mail"``. Errors are logged and swallowed.
        """
        try:
            token_id = self.session.token_id
            if token_id is None or method not in ("totp", "mail"):
                return
            await LoginSessionService.set_login_method(
                self.user.id,
                token_id,
                # Narrowed by the membership check above.
                cast(Literal["totp", "mail"], method),
                self.redis,
            )
        except Exception as e:
            logger.error(f"[Session Verification] Error setting verification method: {e}")

    def is_verified(self) -> bool:
        """Return whether the session is already verified."""
        return self.session.is_verified

    async def mark_verified(self) -> None:
        """Mark the session as verified.

        Uses a dedicated database session that is always closed afterwards.
        Does nothing when the session has no ``token_id``. Errors are logged
        and swallowed.
        """
        try:
            token_id = self.session.token_id
            if token_id is None:
                # Nothing to update, so do not open a database session at all.
                return

            db = with_db()
            try:
                await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id)
            finally:
                await db.close()
        except Exception as e:
            logger.error(f"[Session Verification] Error marking session verified: {e}")

    def get_key(self) -> str:
        """Return the session key: the session id as a string, or ``""`` if it is falsy."""
        return str(self.session.id) if self.session.id else ""

    def get_key_for_event(self) -> str:
        """Return the session key in the form used for event broadcasting."""
        return LoginSessionService.get_key_for_event(self.get_key())

    def user_id(self) -> int:
        """Return the id of the user owning this session."""
        return self.user.id

    async def issue_mail_if_needed(self) -> None:
        """Send the verification e-mail when the method is ``"mail"``.

        Uses a dedicated database session that is always closed afterwards.
        Errors are logged and swallowed.
        """
        try:
            if self.get_method() != "mail":
                return

            from app.service.verification_service import EmailVerificationService

            db = with_db()
            try:
                await EmailVerificationService.send_verification_email(
                    db, self.redis, self.user.id, self.user.username, self.user.email, None, None
                )
            finally:
                await db.close()
        except Exception as e:
            logger.error(f"[Session Verification] Error issuing mail: {e}")
class SessionVerificationController:
    """Session verification controller.

    Modelled on osu-web's ``Controller`` class (per the original author's note).
    """

    # Routes that never trigger verification; mirrors osu-web's
    # SKIP_VERIFICATION_ROUTES.
    SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
        "/api/v2/session/verify",
        "/api/v2/session/verify/reissue",
        "/api/v2/me",
        "/api/v2/logout",
        "/oauth/token",
    }

    @staticmethod
    def should_skip_verification(request: Request) -> bool:
        """Return whether ``request`` targets a route exempt from verification.

        Trailing slashes are ignored, so ``/api/v2/me/`` is treated the same
        as ``/api/v2/me``.
        """
        path = request.url.path.rstrip("/") or "/"
        return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES

    @staticmethod
    async def initiate_verification(
        state: SessionVerificationState,
        request: Request,
    ) -> Response:
        """Start the verification flow and build the 401 response.

        Modelled on osu-web's ``initiate`` method. For the ``"mail"`` method
        the verification e-mail is issued first. Paths under ``/api/`` get a
        compact ``{"method": ...}`` body; every other path gets the longer
        body with ``authentication`` and ``message`` fields.

        Raises:
            HTTPException: 500, chained to the original error, when anything
                in the flow fails.
        """
        try:
            method = state.get_method()

            if method == "mail":
                await state.issue_mail_if_needed()

            # API requests get the compact JSON body.
            if request.url.path.startswith("/api/"):
                return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method})

            # Everything else; an HTML response could be added here later.
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"authentication": "verify", "method": method, "message": "Session verification required"},
            )
        except Exception as e:
            logger.error(f"[Session Verification] Error initiating verification: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed"
            ) from e
class SessionVerificationMiddleware:
    """Session verification middleware.

    Modelled on osu-web's ``VerifyUser`` middleware (per the original
    author's note).
    """

    # Methods that never require a verified session.
    _SAFE_METHODS: ClassVar[frozenset[str]] = frozenset({"GET", "HEAD", "OPTIONS"})

    def __init__(self, app: Callable[[Request], Awaitable[Response]]):
        self.app = app

    async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """Run the verification check, then either block or forward the request.

        Any error raised by the verification check itself is logged and the
        request is allowed through, so verification problems never block
        normal traffic. ``call_next`` is invoked at most once and outside the
        error guard: an exception from the downstream handler propagates
        instead of triggering a second invocation of the handler.
        """
        try:
            blocking_response = await self._check_verification(request)
        except Exception as e:
            logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
            # Fail open: let the request continue rather than block normal flow.
            blocking_response = None

        if blocking_response is not None:
            return blocking_response
        return await call_next(request)

    async def _check_verification(self, request: Request) -> Response | None:
        """Return the "verification required" response, or ``None`` to let the request pass."""
        if SessionVerificationController.should_skip_verification(request):
            return None

        user = await self._get_user(request)
        if not user:
            # Unauthenticated requests are not subject to session verification.
            return None

        db = await self._get_db()
        try:
            redis = await self._get_redis()
            state = await SessionVerificationState.get_current(request, db, redis, user)
            # No resolvable session state, or already verified: nothing to do.
            if not state or state.is_verified():
                return None
            if not self._requires_verification(request):
                return None
            return await SessionVerificationController.initiate_verification(state, request)
        finally:
            # Always release the database session opened for this check.
            await db.close()

    async def _get_user(self, request: Request) -> User | None:
        """Return the current user for ``request``.

        TODO: not implemented yet. Dependency injection is not available
        inside a middleware, so the user has to be resolved manually from
        the token. Until then this always returns ``None``, which makes the
        middleware let every request through.
        """
        return None

    async def _get_db(self) -> AsyncSession:
        """Return a new database session; the caller is responsible for closing it."""
        return with_db()

    async def _get_redis(self) -> Redis:
        """Return the Redis client."""
        return get_redis()

    def _requires_verification(self, request: Request) -> bool:
        """Return whether the request's HTTP method requires a verified session.

        Modelled on osu-web's ``requiresVerification``: GET/HEAD/OPTIONS are
        treated as safe; every other method (POST, PUT, DELETE, ...) requires
        verification.
        """
        return request.method not in self._SAFE_METHODS
# FastAPI middleware wrapper
class _AppPassthrough:
    """ASGI callable that answers by delegating to the wrapped application.

    Returned from ``call_next`` in place of a real response object, so that
    "let the request continue" means the downstream app writes directly to
    the original ``send`` channel.
    """

    def __init__(self, app):
        self._app = app

    async def __call__(self, scope, receive, send):
        await self._app(scope, receive, send)


class FastAPISessionVerificationMiddleware:
    """ASGI wrapper that runs ``SessionVerificationMiddleware`` in front of an app."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        """ASGI entry point.

        Non-HTTP scopes (lifespan, websocket, ...) are forwarded untouched.
        For HTTP scopes the verification middleware decides whether to answer
        with its own response or to let the wrapped application handle the
        request.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope, receive)
        # Only needs to be awaitable as response(scope, receive, send), which the
        # passthrough satisfies; the cast keeps the call_next signature that
        # SessionVerificationMiddleware expects.
        # NOTE(review): relies on the verification middleware not consuming the
        # request body before call_next — confirm if body access is ever added.
        passthrough = cast("Response", _AppPassthrough(self.app))

        async def call_next(req: Request) -> Response:
            # Forward to the wrapped application instead of a canned reply.
            return passthrough

        middleware = SessionVerificationMiddleware(call_next)
        response = await middleware(request, call_next)
        await response(scope, receive, send)