feat(session-verify): 添加 TOTP 支持 (#34)

* chore(deps): add pyotp

* feat(auth): implement TOTP verification

feat(auth): implement TOTP verification and email verification services

- Added TOTP keys management with a new database model `TotpKeys`.
- Introduced `EmailVerification` and `LoginSession` models for email verification.
- Created `verification_service` to handle email verification logic and TOTP processes.
- Updated user response models to include session verification methods.
- Implemented routes for TOTP creation, verification, and fallback to email verification.
- Enhanced login session management to support new location checks and verification methods.
- Added migration script to create `totp_keys` table in the database.

* feat(config): update config example

* docs(totp): complete creating TOTP flow

* refactor(totp): resolve review

* feat(api): forbid unverified request

* fix(totp): trace session by token id to avoid other sessions are forbidden

* chore(linter): make pyright happy

* fix(totp): only mark sessions with a specified token id
This commit is contained in:
MingxuanGame
2025-09-21 19:50:11 +08:00
committed by GitHub
parent 7b4ff1224d
commit 1527e23b43
25 changed files with 684 additions and 235 deletions

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
from datetime import timedelta
from app.database.email_verification import EmailVerification, LoginSession
from app.database.verification import EmailVerification, LoginSession
from app.log import logger
from app.utils import utcnow

View File

@@ -9,7 +9,7 @@ import asyncio
from app.database.user_login_log import UserLoginLog
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
from app.log import logger
from app.utils import utcnow
from app.utils import simplify_user_agent, utcnow
from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -45,9 +45,6 @@ class LoginLogService:
raw_ip = get_client_ip(request)
ip_address = normalize_ip(raw_ip)
# 获取并简化User-Agent
from app.utils import simplify_user_agent
raw_user_agent = request.headers.get("User-Agent", "")
user_agent = simplify_user_agent(raw_user_agent, max_length=500)

View File

@@ -7,15 +7,16 @@ from __future__ import annotations
from datetime import timedelta
import secrets
import string
from typing import Literal
from app.config import settings
from app.database.email_verification import EmailVerification, LoginSession
from app.database.verification import EmailVerification, LoginSession
from app.log import logger
from app.service.email_queue import email_queue # 导入邮件队列
from app.utils import utcnow
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -279,20 +280,18 @@ This email was sent automatically, please do not reply.
return False
@staticmethod
async def verify_code(
async def verify_email_code(
db: AsyncSession,
redis: Redis,
user_id: int,
code: str,
ip_address: str | None = None,
) -> tuple[bool, str]:
"""验证验证码"""
"""验证邮箱验证码"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
# 仍然标记登录会话为已验证
await LoginSessionService.mark_session_verified(db, user_id)
return True, "验证成功(邮件验证功能已禁用)"
# 先从 Redis 检查
@@ -319,9 +318,6 @@ This email was sent automatically, please do not reply.
verification.is_used = True
verification.used_at = utcnow()
# 同时更新对应的登录会话状态
await LoginSessionService.mark_session_verified(db, user_id)
await db.commit()
# 删除 Redis 记录
@@ -382,10 +378,12 @@ class LoginSessionService:
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False,
is_verified: bool = False,
) -> LoginSession:
"""创建登录会话"""
@@ -393,13 +391,13 @@ class LoginSessionService:
session = LoginSession(
user_id=user_id,
session_token=session_token,
token_id=token_id,
ip_address=ip_address,
user_agent=None,
country_code=country_code,
is_new_location=is_new_location,
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
is_verified=not is_new_location, # 新位置需要验证
is_verified=is_verified,
)
db.add(session)
@@ -416,46 +414,21 @@ class LoginSessionService:
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
return session
@staticmethod
async def verify_session(
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
) -> tuple[bool, str]:
"""验证会话(通过邮件验证码)"""
try:
# 从 Redis 获取用户ID
user_id = await redis.get(f"login_session:{session_token}")
if not user_id:
return False, "会话无效或已过期"
@classmethod
def _session_verify_redis_key(cls, user_id: int, token_id: int) -> str:
return f"session_verification_method:{user_id}:{token_id}"
user_id = int(user_id)
@classmethod
async def get_login_method(cls, user_id: int, token_id: int, redis: Redis) -> Literal["totp", "mail"] | None:
return await redis.get(cls._session_verify_redis_key(user_id, token_id))
# 验证邮件验证码
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
@classmethod
async def set_login_method(cls, user_id: int, token_id: int, method: Literal["totp", "mail"], redis: Redis) -> None:
await redis.set(cls._session_verify_redis_key(user_id, token_id), method)
if not success:
return False, message
# 更新会话状态
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_token,
LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False),
)
)
session = result.first()
if session:
session.is_verified = True
session.verified_at = utcnow()
await db.commit()
logger.info(f"[Login Session] User {user_id} session verification successful")
return True, "会话验证成功"
except Exception as e:
logger.error(f"[Login Session] Exception during session verification: {e}")
return False, "验证过程中发生错误"
@classmethod
async def clear_login_method(cls, user_id: int, token_id: int, redis: Redis) -> None:
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
@staticmethod
async def check_new_location(
@@ -485,7 +458,7 @@ class LoginSessionService:
return True
@staticmethod
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
"""标记用户的未验证会话为已验证"""
try:
# 查找用户所有未验证且未过期的会话
@@ -494,6 +467,7 @@ class LoginSessionService:
LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
)
)
@@ -507,8 +481,27 @@ class LoginSessionService:
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
await LoginSessionService.clear_login_method(user_id, token_id, redis)
return len(sessions) > 0
except Exception as e:
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
return False
@staticmethod
async def check_is_need_verification(db: AsyncSession, user_id: int, token_id: int) -> bool:
"""检查用户是否需要验证(有未验证的会话)"""
if settings.enable_totp_verification or settings.enable_email_verification:
unverified_session = (
await db.exec(
select(exists()).where(
LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
)
)
).first()
return unverified_session or False
return False