feat(session-verify): 添加 TOTP 支持 (#34)
* chore(deps): add pyotp * feat(auth): implement TOTP verification feat(auth): implement TOTP verification and email verification services - Added TOTP keys management with a new database model `TotpKeys`. - Introduced `EmailVerification` and `LoginSession` models for email verification. - Created `verification_service` to handle email verification logic and TOTP processes. - Updated user response models to include session verification methods. - Implemented routes for TOTP creation, verification, and fallback to email verification. - Enhanced login session management to support new location checks and verification methods. - Added migration script to create `totp_keys` table in the database. * feat(config): update config example * docs(totp): complete creating TOTP flow * refactor(totp): resolve review * feat(api): forbid unverified request * fix(totp): trace session by token id to avoid other sessions are forbidden * chore(linter): make pyright happy * fix(totp): only mark sessions with a specified token id
This commit is contained in:
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.utils import utcnow
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
from app.database.user_login_log import UserLoginLog
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
|
||||
from app.log import logger
|
||||
from app.utils import utcnow
|
||||
from app.utils import simplify_user_agent, utcnow
|
||||
|
||||
from fastapi import Request
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -45,9 +45,6 @@ class LoginLogService:
|
||||
raw_ip = get_client_ip(request)
|
||||
ip_address = normalize_ip(raw_ip)
|
||||
|
||||
# 获取并简化User-Agent
|
||||
from app.utils import simplify_user_agent
|
||||
|
||||
raw_user_agent = request.headers.get("User-Agent", "")
|
||||
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
|
||||
|
||||
|
||||
@@ -7,15 +7,16 @@ from __future__ import annotations
|
||||
from datetime import timedelta
|
||||
import secrets
|
||||
import string
|
||||
from typing import Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -279,20 +280,18 @@ This email was sent automatically, please do not reply.
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def verify_code(
|
||||
async def verify_email_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
code: str,
|
||||
ip_address: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
"""验证邮箱验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
|
||||
# 仍然标记登录会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
return True, "验证成功(邮件验证功能已禁用)"
|
||||
|
||||
# 先从 Redis 检查
|
||||
@@ -319,9 +318,6 @@ This email was sent automatically, please do not reply.
|
||||
verification.is_used = True
|
||||
verification.used_at = utcnow()
|
||||
|
||||
# 同时更新对应的登录会话状态
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 删除 Redis 记录
|
||||
@@ -382,10 +378,12 @@ class LoginSessionService:
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
token_id: int,
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
is_verified: bool = False,
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
|
||||
@@ -393,13 +391,13 @@ class LoginSessionService:
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
token_id=token_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=None,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
|
||||
is_verified=not is_new_location, # 新位置需要验证
|
||||
is_verified=is_verified,
|
||||
)
|
||||
|
||||
db.add(session)
|
||||
@@ -416,46 +414,21 @@ class LoginSessionService:
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
async def verify_session(
|
||||
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
|
||||
) -> tuple[bool, str]:
|
||||
"""验证会话(通过邮件验证码)"""
|
||||
try:
|
||||
# 从 Redis 获取用户ID
|
||||
user_id = await redis.get(f"login_session:{session_token}")
|
||||
if not user_id:
|
||||
return False, "会话无效或已过期"
|
||||
@classmethod
|
||||
def _session_verify_redis_key(cls, user_id: int, token_id: int) -> str:
|
||||
return f"session_verification_method:{user_id}:{token_id}"
|
||||
|
||||
user_id = int(user_id)
|
||||
@classmethod
|
||||
async def get_login_method(cls, user_id: int, token_id: int, redis: Redis) -> Literal["totp", "mail"] | None:
|
||||
return await redis.get(cls._session_verify_redis_key(user_id, token_id))
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
|
||||
@classmethod
|
||||
async def set_login_method(cls, user_id: int, token_id: int, method: Literal["totp", "mail"], redis: Redis) -> None:
|
||||
await redis.set(cls._session_verify_redis_key(user_id, token_id), method)
|
||||
|
||||
if not success:
|
||||
return False, message
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_token,
|
||||
LoginSession.user_id == user_id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
session = result.first()
|
||||
if session:
|
||||
session.is_verified = True
|
||||
session.verified_at = utcnow()
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||
return True, "会话验证成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
@classmethod
|
||||
async def clear_login_method(cls, user_id: int, token_id: int, redis: Redis) -> None:
|
||||
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
@@ -485,7 +458,7 @@ class LoginSessionService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
|
||||
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
@@ -494,6 +467,7 @@ class LoginSessionService:
|
||||
LoginSession.user_id == user_id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -507,8 +481,27 @@ class LoginSessionService:
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
await LoginSessionService.clear_login_method(user_id, token_id, redis)
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def check_is_need_verification(db: AsyncSession, user_id: int, token_id: int) -> bool:
|
||||
"""检查用户是否需要验证(有未验证的会话)"""
|
||||
if settings.enable_totp_verification or settings.enable_email_verification:
|
||||
unverified_session = (
|
||||
await db.exec(
|
||||
select(exists()).where(
|
||||
LoginSession.user_id == user_id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
return unverified_session or False
|
||||
return False
|
||||
Reference in New Issue
Block a user