添加邮件验证
This commit is contained in:
367
app/service/email_verification_service.py
Normal file
367
app/service/email_verification_service.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
邮件验证管理服务
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_service import email_service
|
||||
from app.log import logger
|
||||
from app.config import settings
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel import select
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class EmailVerificationService:
|
||||
"""邮件验证服务"""
|
||||
|
||||
@staticmethod
|
||||
def generate_verification_code() -> str:
|
||||
"""生成8位验证码"""
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
@staticmethod
|
||||
def generate_session_token() -> str:
|
||||
"""生成会话令牌"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
async def create_verification_record(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
) -> tuple[EmailVerification, str]:
|
||||
"""创建邮件验证记录"""
|
||||
|
||||
# 检查是否有未过期的验证码
|
||||
existing_result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
existing = existing_result.first()
|
||||
|
||||
if existing:
|
||||
# 如果有未过期的验证码,直接返回
|
||||
return existing, existing.verification_code
|
||||
|
||||
# 生成新的验证码
|
||||
code = EmailVerificationService.generate_verification_code()
|
||||
|
||||
# 创建验证记录
|
||||
verification = EmailVerification(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
verification_code=code,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
db.add(verification)
|
||||
await db.commit()
|
||||
await db.refresh(verification)
|
||||
|
||||
# 存储到 Redis(用于快速验证)
|
||||
await redis.setex(
|
||||
f"email_verification:{user_id}:{code}",
|
||||
600, # 10分钟过期
|
||||
str(verification.id) if verification.id else "0"
|
||||
)
|
||||
|
||||
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
|
||||
return verification, code
|
||||
|
||||
@staticmethod
|
||||
async def send_verification_email(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
|
||||
return True # 返回成功,但不执行验证流程
|
||||
|
||||
# 创建验证记录
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
# 发送邮件
|
||||
success = await email_service.send_verification_email(email, code, username)
|
||||
|
||||
if success:
|
||||
logger.info(f"[Email Verification] Successfully sent verification email to {email} (user: {username})")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[Email Verification] Failed to send verification email: {email} (user: {username})")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def verify_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
code: str,
|
||||
ip_address: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
|
||||
# 仍然标记登录会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
return True, "验证成功(邮件验证功能已禁用)"
|
||||
|
||||
# 先从 Redis 检查
|
||||
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
|
||||
if not verification_id:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
# 从数据库获取验证记录
|
||||
result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.id == int(verification_id),
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.verification_code == code,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
|
||||
verification = result.first()
|
||||
if not verification:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
# 标记为已使用
|
||||
verification.is_used = True
|
||||
verification.used_at = datetime.now(UTC)
|
||||
|
||||
# 同时更新对应的登录会话状态
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 删除 Redis 记录
|
||||
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||
|
||||
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||
return True, "验证成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
@staticmethod
|
||||
async def resend_verification_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""重新发送验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
|
||||
return True, "验证码已发送(邮件验证功能已禁用)"
|
||||
|
||||
# 检查重发频率限制(60秒内只能发送一次)
|
||||
rate_limit_key = f"email_verification_rate_limit:{user_id}"
|
||||
if await redis.get(rate_limit_key):
|
||||
return False, "请等待60秒后再重新发送"
|
||||
|
||||
# 设置频率限制
|
||||
await redis.setex(rate_limit_key, 60, "1")
|
||||
|
||||
# 生成新的验证码
|
||||
success = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, username, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
if success:
|
||||
return True, "验证码已重新发送"
|
||||
else:
|
||||
return False, "重新发送失败,请稍后再试"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
|
||||
return False, "重新发送过程中发生错误"
|
||||
|
||||
|
||||
class LoginSessionService:
|
||||
"""登录会话服务"""
|
||||
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
session_token = EmailVerificationService.generate_session_token()
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
|
||||
is_verified=not is_new_location # 新位置需要验证
|
||||
)
|
||||
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
# 存储到 Redis
|
||||
await redis.setex(
|
||||
f"login_session:{session_token}",
|
||||
86400, # 24小时
|
||||
user_id
|
||||
)
|
||||
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
async def verify_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
session_token: str,
|
||||
verification_code: str
|
||||
) -> tuple[bool, str]:
|
||||
"""验证会话(通过邮件验证码)"""
|
||||
try:
|
||||
# 从 Redis 获取用户ID
|
||||
user_id = await redis.get(f"login_session:{session_token}")
|
||||
if not user_id:
|
||||
return False, "会话无效或已过期"
|
||||
|
||||
user_id = int(user_id)
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_code
|
||||
)
|
||||
|
||||
if not success:
|
||||
return False, message
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_token,
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False
|
||||
)
|
||||
)
|
||||
|
||||
session = result.first()
|
||||
if session:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||
return True, "会话验证成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
country_code: str | None = None
|
||||
) -> bool:
|
||||
"""检查是否为新位置登录"""
|
||||
try:
|
||||
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.created_at > thirty_days_ago,
|
||||
(LoginSession.ip_address == ip_address) |
|
||||
(LoginSession.country_code == country_code)
|
||||
)
|
||||
)
|
||||
|
||||
existing_sessions = result.all()
|
||||
|
||||
# 如果有历史记录,则不是新位置
|
||||
return len(existing_sessions) == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||
# 出错时默认为新位置(更安全)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(
|
||||
db: AsyncSession,
|
||||
user_id: int
|
||||
) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.expires_at > datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
|
||||
sessions = result.all()
|
||||
|
||||
# 标记所有会话为已验证
|
||||
for session in sessions:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||
return False
|
||||
Reference in New Issue
Block a user