feat(auth): support trusted device (#52)

New API to maintain sessions and devices:

- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}

Auth:

web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-03 11:26:43 +08:00
committed by GitHub
parent f34ed53a55
commit 40670c094b
28 changed files with 897 additions and 1456 deletions

View File

@@ -1,301 +0,0 @@
"""
会话验证中间件和状态管理
基于osu-web的会话验证系统实现
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import ClassVar, Literal, cast
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from app.utils import bg_tasks
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
class SessionVerificationState:
"""会话验证状态管理类
参考osu-web的State类实现
"""
def __init__(self, session: LoginSession, user: User, redis: Redis):
self.session = session
self.user = user
self.redis = redis
@classmethod
async def get_current(
cls,
request: Request,
db: AsyncSession,
redis: Redis,
user: User,
) -> SessionVerificationState | None:
"""获取当前会话验证状态"""
try:
# 从请求头或token中获取会话信息
session_token = cls._extract_session_token(request)
if not session_token:
return None
# 查找会话
session = await LoginSessionService.find_for_verification(db, session_token)
if not session or session.user_id != user.id:
return None
return cls(session, user, redis)
except Exception as e:
logger.error(f"[Session Verification] Error getting current state: {e}")
return None
@staticmethod
def _extract_session_token(request: Request) -> str | None:
"""从请求中提取会话token"""
# 尝试从Authorization header提取
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:] # 移除"Bearer "前缀
# 可以扩展其他提取方式
return None
def get_method(self) -> str:
"""获取验证方法
参考osu-web的逻辑智能选择验证方法
"""
current_method = self.session.verification_method
if current_method is None:
# 智能选择验证方法
# 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证
# 这里简化为检查用户是否有TOTP
totp_key = getattr(self.user, "totp_key", None)
current_method = "totp" if totp_key else "mail"
# 设置验证方法
bg_tasks.add_task(self._set_verification_method, current_method)
return current_method
async def _set_verification_method(self, method: str) -> None:
"""内部方法:设置验证方法"""
try:
token_id = self.session.token_id
if token_id is not None and method in ["totp", "mail"]:
# 类型检查确保method是正确的字面量类型
verification_method = method if method in ["totp", "mail"] else "totp"
await LoginSessionService.set_login_method(
self.user.id,
token_id,
cast(Literal["totp", "mail"], verification_method),
self.redis,
)
except Exception as e:
logger.error(f"[Session Verification] Error setting verification method: {e}")
def is_verified(self) -> bool:
"""检查会话是否已验证"""
return self.session.is_verified
async def mark_verified(self) -> None:
"""标记会话为已验证"""
try:
async with with_db() as db:
# 创建专用数据库会话
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id)
except Exception as e:
logger.error(f"[Session Verification] Error marking session verified: {e}")
def get_key(self) -> str:
"""获取会话密钥"""
return str(self.session.id) if self.session.id else ""
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
return LoginSessionService.get_key_for_event(self.get_key())
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id
async def issue_mail_if_needed(self) -> None:
"""如果需要,发送验证邮件"""
try:
if self.get_method() == "mail":
from app.service.verification_service import EmailVerificationService
# 创建专用数据库会话发送邮件
async with with_db() as db:
await EmailVerificationService.send_verification_email(
db, self.redis, self.user.id, self.user.username, self.user.email, None, None
)
except Exception as e:
logger.error(f"[Session Verification] Error issuing mail: {e}")
class SessionVerificationController:
"""会话验证控制器
参考osu-web的Controller类实现
"""
# 需要跳过验证的路由参考osu-web的SKIP_VERIFICATION_ROUTES
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/me",
"/api/v2/logout",
"/oauth/token",
}
@staticmethod
def should_skip_verification(request: Request) -> bool:
"""检查是否应该跳过验证"""
path = request.url.path
return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES
@staticmethod
async def initiate_verification(
state: SessionVerificationState,
request: Request,
) -> Response:
"""启动会话验证流程
参考osu-web的initiate方法
"""
try:
method = state.get_method()
# 如果是邮件验证,发送验证邮件
if method == "mail":
await state.issue_mail_if_needed()
# API请求返回JSON响应
if request.url.path.startswith("/api/"):
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method})
# 其他情况可以扩展支持HTML响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"authentication": "verify", "method": method, "message": "Session verification required"},
)
except Exception as e:
logger.error(f"[Session Verification] Error initiating verification: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed"
)
class SessionVerificationMiddleware:
"""会话验证中间件
参考osu-web的VerifyUser中间件实现
"""
def __init__(self, app: Callable[[Request], Awaitable[Response]]):
self.app = app
async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
"""中间件主要逻辑"""
try:
# 检查是否需要跳过验证
if SessionVerificationController.should_skip_verification(request):
return await call_next(request)
# 获取依赖项
user = await self._get_user(request)
if not user:
# 未认证用户跳过验证
return await call_next(request)
# 获取数据库和Redis连接
async with with_db() as db:
redis = await self._get_redis()
# 获取会话验证状态
state = await SessionVerificationState.get_current(request, db, redis, user)
if not state:
# 无法获取会话状态,继续请求
return await call_next(request)
# 检查是否已验证
if state.is_verified():
# 已验证,继续请求
return await call_next(request)
# 检查是否需要验证
if not self._requires_verification(request):
return await call_next(request)
# 启动验证流程
return await SessionVerificationController.initiate_verification(state, request)
except Exception as e:
logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
# 出错时允许请求继续,避免阻塞正常流程
return await call_next(request)
async def _get_user(self, request: Request) -> User | None:
"""获取当前用户"""
try:
# 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入
# 简化实现实际应该从token中解析用户
return None # 暂时返回None需要实际实现
except Exception:
return None
async def _get_redis(self) -> Redis:
"""获取Redis连接"""
return get_redis()
def _requires_verification(self, request: Request) -> bool:
"""检查是否需要验证
参考osu-web的requiresVerification方法
"""
method = request.method
# GET/HEAD/OPTIONS请求一般不需要验证
safe_methods = {"GET", "HEAD", "OPTIONS"}
if method in safe_methods:
return False
# POST/PUT/DELETE等修改操作需要验证
return True
# FastAPI中间件包装器
class FastAPISessionVerificationMiddleware:
"""FastAPI会话验证中间件包装器"""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope, receive)
async def call_next(req: Request) -> Response:
# 这里需要调用FastAPI应用
return Response("OK") # 占位符实现
middleware = SessionVerificationMiddleware(call_next)
response = await middleware(request, call_next)
await response(scope, receive, send)

View File

@@ -10,11 +10,13 @@ from collections.abc import Callable
from typing import ClassVar
from app.auth import get_token_by_access_token
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from app.utils import extract_user_agent
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
@@ -34,7 +36,9 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/session/verify/mail-fallback",
"/api/v2/me",
"/api/v2/me/",
"/api/v2/logout",
"/oauth/token",
"/health",
@@ -44,10 +48,8 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
"/redoc",
}
# 需要强制验证的路由模式(敏感操作)
# 总是需要验证的路由前缀
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
"/api/v2/account/",
"/api/v2/settings/",
"/api/private/admin/",
}
@@ -110,9 +112,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
if path.startswith(pattern):
return True
# 特权用户或非活跃用户需要验证
# if hasattr(user, 'is_privileged') and user.is_privileged():
# return True
if not user.is_active:
return True
@@ -154,6 +153,14 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
try:
# 提取会话token这里简化为使用相同的auth token
auth_header = request.headers.get("Authorization", "")
api_version = 0
raw_api_version = request.headers.get("x-api-version")
if raw_api_version is not None:
try:
api_version = int(raw_api_version)
except ValueError:
api_version = 0
if not auth_header.startswith("Bearer "):
return None
@@ -168,7 +175,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
if not session or session.user_id != user.id:
return None
return SessionState(session, user, redis, db)
return SessionState(session, user, redis, db, api_version)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
@@ -178,8 +185,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
"""启动验证流程"""
try:
method = await state.get_method()
# 如果是邮件验证,可以在这里触发发送邮件
if method == "mail":
await state.issue_mail_if_needed()
@@ -202,11 +207,12 @@ class SessionState:
简化版本的会话状态管理
"""
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession):
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession, api_version: int = 0) -> None:
self.session = session
self.user = user
self.redis = redis
self.db = db
self.api_version = api_version
self._verification_method: str | None = None
def is_verified(self) -> bool:
@@ -223,14 +229,15 @@ class SessionState:
self.user.id, token_id, self.redis
)
# 如果没有设置,智能选择
if self._verification_method is None:
# 检查用户是否有TOTP密钥
await self.user.awaitable_attrs.totp_key # 预加载
totp_key = getattr(self.user, "totp_key", None)
if self.api_version < SUPPORT_TOTP_VERIFICATION_VER:
self._verification_method = "mail"
return self._verification_method
await self.user.awaitable_attrs.totp_key
totp_key = self.user.totp_key
self._verification_method = "totp" if totp_key else "mail"
# 保存选择的方法
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.set_login_method(
@@ -244,8 +251,15 @@ class SessionState:
try:
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(self.db, self.redis, self.user.id, token_id)
self.session.is_verified = True # 更新本地状态
await LoginSessionService.mark_session_verified(
self.db,
self.redis,
self.user.id,
token_id,
self.session.ip_address,
extract_user_agent(self.session.user_agent),
self.session.web_uuid,
)
except Exception as e:
logger.error(f"[Session State] Error marking verified: {e}")
@@ -266,10 +280,12 @@ class SessionState:
"""获取会话密钥"""
return str(self.session.id) if self.session.id else ""
def get_key_for_event(self) -> str:
@property
def key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
return LoginSessionService.get_key_for_event(self.get_key())
@property
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id