修复邮件验证部分问题

This commit is contained in:
咕谷酱
2025-09-23 22:13:15 +08:00
parent 7d6eeae073
commit 99d6af1c1f
10 changed files with 770 additions and 31 deletions

4
.gitignore vendored
View File

@@ -223,3 +223,7 @@ logs/
osu-server-spectator-master/* osu-server-spectator-master/*
spectator-server/ spectator-server/
.github/copilot-instructions.md .github/copilot-instructions.md
osu-web-master/*
osu-web-master/.env.dusk.local.example
osu-web-master/.env.example
osu-web-master/.env.testing.example

View File

@@ -217,15 +217,41 @@ async def store_token(
access_token: str, access_token: str,
refresh_token: str, refresh_token: str,
expires_in: int, expires_in: int,
allow_multiple_devices: bool = True,
) -> OAuthToken: ) -> OAuthToken:
"""存储令牌到数据库""" """存储令牌到数据库(支持多设备)"""
expires_at = utcnow() + timedelta(seconds=expires_in) expires_at = utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌 if not allow_multiple_devices:
# 旧的行为:删除用户的旧令牌(单设备模式)
statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id) statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id)
old_tokens = (await db.exec(statement)).all() old_tokens = (await db.exec(statement)).all()
for token in old_tokens: for token in old_tokens:
await db.delete(token) await db.delete(token)
else:
# 新的行为:只删除过期的令牌,保留有效的令牌(多设备模式)
statement = select(OAuthToken).where(
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at <= utcnow()
)
expired_tokens = (await db.exec(statement)).all()
for token in expired_tokens:
await db.delete(token)
# 限制每个用户每个客户端的最大令牌数量(防止无限增长)
max_tokens_per_client = settings.max_tokens_per_client
statement = (
select(OAuthToken)
.where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at > utcnow())
.order_by(OAuthToken.created_at.desc())
)
active_tokens = (await db.exec(statement)).all()
if len(active_tokens) >= max_tokens_per_client:
# 删除最旧的令牌
tokens_to_delete = active_tokens[max_tokens_per_client - 1 :]
for token in tokens_to_delete:
await db.delete(token)
logger.info(f"[Auth] Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}")
# 检查是否有重复的 access_token # 检查是否有重复的 access_token
duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first() duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
@@ -244,6 +270,10 @@ async def store_token(
db.add(token_record) db.add(token_record)
await db.commit() await db.commit()
await db.refresh(token_record) await db.refresh(token_record)
logger.info(
f"[Auth] Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})"
)
return token_record return token_record

View File

@@ -309,6 +309,31 @@ STORAGE_SETTINGS='{
Field(default=False, description="是否启用邮件验证功能"), Field(default=False, description="是否启用邮件验证功能"),
"验证服务设置", "验证服务设置",
] ]
enable_smart_verification: Annotated[
bool,
Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"),
"验证服务设置",
]
enable_multi_device_login: Annotated[
bool,
Field(default=True, description="是否允许多设备同时登录"),
"验证服务设置",
]
max_tokens_per_client: Annotated[
int,
Field(default=10, description="每个用户每个客户端的最大令牌数量"),
"验证服务设置",
]
device_trust_duration_days: Annotated[
int,
Field(default=30, description="设备信任持续天数"),
"验证服务设置",
]
location_trust_duration_days: Annotated[
int,
Field(default=90, description="位置信任持续天数"),
"验证服务设置",
]
smtp_server: Annotated[ smtp_server: Annotated[
str, str,
Field(default="localhost", description="SMTP 服务器地址"), Field(default="localhost", description="SMTP 服务器地址"),

View File

@@ -214,6 +214,9 @@ async def oauth_token(
): ):
scopes = scope.split(" ") scopes = scope.split(" ")
# 打印请求头
# logger.info(f"Request headers: {request.headers}")
client = ( client = (
await db.exec( await db.exec(
select(OAuthClient).where( select(OAuthClient).where(
@@ -303,6 +306,7 @@ async def oauth_token(
access_token, access_token,
refresh_token_str, refresh_token_str,
settings.access_token_expire_minutes * 60, settings.access_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
) )
token_id = token.id token_id = token.id
@@ -333,24 +337,41 @@ async def oauth_token(
await db.refresh(user) await db.refresh(user)
session_verification_method = "mail" session_verification_method = "mail"
# 发送邮件验证码 # 使用智能验证发送邮件
verification_sent = await EmailVerificationService.send_verification_email( (
db, redis, user_id, user.username, user.email, ip_address, user_agent verification_sent,
verification_message,
client_info,
) = await EmailVerificationService.send_smart_verification_email(
db,
redis,
user_id,
user.username,
user.email,
ip_address,
user_agent,
client_id,
country_code,
is_new_location,
) )
# 记录需要二次验证的登录尝试 # 记录需要二次验证的登录尝试
client_display_name = client_info.client_type if client_info else "unknown"
await LoginLogService.record_login( await LoginLogService.record_login(
db=db, db=db,
user_id=user_id, user_id=user_id,
request=request, request=request,
login_success=True, login_success=True,
login_method="password_pending_verification", login_method="password_pending_verification",
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}", notes=f"智能验证: {verification_message} - 客户端: {client_display_name}, "
f"IP: {ip_address}, 国家: {country_code}",
) )
if not verification_sent: if not verification_sent:
# 邮件发送失败,记录错误 # 邮件发送失败,记录错误
logger.error(f"[Auth] Failed to send email verification code for user {user_id}") logger.error(f"[Auth] Smart verification failed for user {user_id}: {verification_message}")
else:
logger.info(f"[Auth] Smart verification result for user {user_id}: {verification_message}")
elif is_new_location: elif is_new_location:
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id) await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
@@ -428,6 +449,7 @@ async def oauth_token(
access_token, access_token,
new_refresh_token, new_refresh_token,
settings.access_token_expire_minutes * 60, settings.access_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
) )
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
@@ -492,6 +514,7 @@ async def oauth_token(
access_token, access_token,
refresh_token_str, refresh_token_str,
settings.access_token_expire_minutes * 60, settings.access_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
) )
# 打印jwt # 打印jwt
@@ -538,6 +561,7 @@ async def oauth_token(
access_token, access_token,
refresh_token_str, refresh_token_str,
settings.access_token_expire_minutes * 60, settings.access_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
) )
return TokenResponse( return TokenResponse(

View File

@@ -0,0 +1,230 @@
"""
客户端检测服务
用于识别不同类型的 osu! 客户端和设备
"""
from __future__ import annotations
from dataclasses import dataclass
import hashlib
import re
from typing import ClassVar, Literal
from app.log import logger
@dataclass
class ClientInfo:
"""客户端信息"""
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"]
platform: str | None = None
version: str | None = None
device_fingerprint: str | None = None
is_trusted_client: bool = False
class ClientDetectionService:
"""客户端检测服务"""
# osu! 客户端的 User-Agent 模式
OSU_CLIENT_PATTERNS: ClassVar[dict[str, list[str]]] = {
"osu_stable": [
r"osu!/(\d+(?:\.\d+)*)", # osu!/20241001
r"osu!", # 简单匹配
],
"osu_lazer": [
r"osu-lazer/(\d+(?:\.\d+)*)", # osu-lazer/2024.1009.0
r"osu!lazer/(\d+(?:\.\d+)*)", # osu!lazer/2024.1009.0
],
"osu_web": [
r"Mozilla.*osu\.ppy\.sh", # 网页客户端
],
"mobile": [
r"osu!.*mobile",
r"osu.*Mobile",
r"Mobile.*osu",
],
}
# 受信任的客户端类型(不需要频繁验证)
TRUSTED_CLIENT_TYPES: ClassVar[set[str]] = {"osu_stable", "osu_lazer"}
@staticmethod
def detect_client(user_agent: str | None, client_id: int | None = None) -> ClientInfo:
"""
检测客户端类型和信息
Args:
user_agent: 用户代理字符串
client_id: OAuth 客户端 ID
Returns:
ClientInfo: 客户端信息
"""
from app.config import settings # 导入在函数内部避免循环导入
if not user_agent:
return ClientInfo(client_type="unknown")
# 优先通过 client_id 判断客户端类型
if client_id is not None:
if client_id == settings.osu_client_id:
# osu! stable 客户端
return ClientInfo(
client_type="osu_stable",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=True,
)
elif client_id == settings.osu_web_client_id:
# 检查 User-Agent 是否表明这是 Lazer 客户端
if user_agent and user_agent.strip() == "osu!":
# Lazer 客户端使用 web client_id 但发送简单的 "osu!" User-Agent
return ClientInfo(
client_type="osu_lazer",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=True,
)
else:
# 真正的 web 客户端
return ClientInfo(
client_type="osu_web",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
# 回退到基于 User-Agent 的检测
for client_type_str, patterns in ClientDetectionService.OSU_CLIENT_PATTERNS.items():
for pattern in patterns:
match = re.search(pattern, user_agent, re.IGNORECASE)
if match:
version = match.group(1) if match.groups() else None
platform = ClientDetectionService._extract_platform(user_agent)
# 确保 client_type 是正确的 Literal 类型
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] = client_type_str # type: ignore
return ClientInfo(
client_type=client_type,
platform=platform,
version=version,
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=client_type in ClientDetectionService.TRUSTED_CLIENT_TYPES,
)
# 检测常见浏览器
if any(browser in user_agent.lower() for browser in ["chrome", "firefox", "safari", "edge"]):
return ClientInfo(
client_type="osu_web",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
return ClientInfo(
client_type="unknown",
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
@staticmethod
def _extract_platform(user_agent: str) -> str | None:
"""从 User-Agent 中提取平台信息"""
platforms = {
"windows": ["windows", "win32", "win64"],
"macos": ["macintosh", "mac os", "darwin"],
"linux": ["linux", "ubuntu", "debian"],
"android": ["android"],
"ios": ["iphone", "ipad", "ios"],
}
user_agent_lower = user_agent.lower()
for platform, keywords in platforms.items():
if any(keyword in user_agent_lower for keyword in keywords):
return platform
return None
@staticmethod
def _generate_device_fingerprint(user_agent: str) -> str:
"""生成设备指纹"""
# 使用 User-Agent 的哈希值作为简单的设备指纹
# 在实际应用中可以结合更多信息IP、屏幕分辨率等
return hashlib.sha256(user_agent.encode()).hexdigest()[:16]
@staticmethod
def should_skip_email_verification(
client_info: ClientInfo,
is_new_location: bool,
user_id: int,
) -> bool:
"""
判断是否应该跳过邮件验证
Args:
client_info: 客户端信息
is_new_location: 是否为新位置登录
user_id: 用户 ID
Returns:
bool: 是否应该跳过邮件验证
"""
# 受信任的客户端类型可以减少验证频率
if client_info.is_trusted_client:
logger.info(
f"[Client Detection] Trusted client {client_info.client_type} for user {user_id}, "
f"reducing verification requirements"
)
return True
# 如果不是新位置,跳过验证
if not is_new_location:
return True
return False
@staticmethod
def get_verification_cooldown(client_info: ClientInfo) -> int:
"""
获取验证冷却时间(秒)
Args:
client_info: 客户端信息
Returns:
int: 冷却时间(秒)
"""
# 受信任的客户端有更长的冷却时间
if client_info.is_trusted_client:
return 3600 # 1小时
# 网页客户端较短的冷却时间
if client_info.client_type == "osu_web":
return 1800 # 30分钟
# 未知客户端最短冷却时间
return 900 # 15分钟
@staticmethod
def format_client_display_name(client_info: ClientInfo) -> str:
"""格式化客户端显示名称"""
display_names = {
"osu_stable": "osu! (stable)",
"osu_lazer": "osu!(lazer)",
"osu_web": "osu! web",
"mobile": "osu! mobile",
"unknown": "Unknown client",
}
base_name = display_names.get(client_info.client_type, "Unknown client")
if client_info.version:
base_name += f" v{client_info.version}"
if client_info.platform:
base_name += f" ({client_info.platform})"
return base_name

View File

@@ -0,0 +1,283 @@
"""
设备信任服务
管理用户的受信任设备,减少频繁验证
"""
from __future__ import annotations
from datetime import timedelta
from app.config import settings
from app.log import logger
from app.service.client_detection_service import ClientInfo
from app.utils import utcnow
from redis.asyncio import Redis
class DeviceTrustService:
"""设备信任服务"""
@staticmethod
def _get_device_trust_key(user_id: int, device_fingerprint: str) -> str:
"""获取设备信任的 Redis 键"""
return f"device_trust:{user_id}:{device_fingerprint}"
@staticmethod
def _get_location_trust_key(user_id: int, country_code: str) -> str:
"""获取位置信任的 Redis 键"""
return f"location_trust:{user_id}:{country_code}"
@staticmethod
def _get_verification_cooldown_key(user_id: int) -> str:
"""获取验证冷却的 Redis 键"""
return f"verification_cooldown:{user_id}"
@staticmethod
async def is_device_trusted(
redis: Redis,
user_id: int,
device_fingerprint: str,
) -> bool:
"""
检查设备是否受信任
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
Returns:
bool: 设备是否受信任
"""
if not device_fingerprint:
return False
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
trust_data = await redis.get(trust_key)
return trust_data is not None
@staticmethod
async def is_location_trusted(
redis: Redis,
user_id: int,
country_code: str | None,
) -> bool:
"""
检查位置是否受信任
Args:
redis: Redis 连接
user_id: 用户 ID
country_code: 国家代码
Returns:
bool: 位置是否受信任
"""
if not country_code:
return False
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
trust_data = await redis.get(trust_key)
return trust_data is not None
@staticmethod
async def is_in_verification_cooldown(
redis: Redis,
user_id: int,
) -> bool:
"""
检查用户是否在验证冷却期内
Args:
redis: Redis 连接
user_id: 用户 ID
Returns:
bool: 是否在冷却期内
"""
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
cooldown_data = await redis.get(cooldown_key)
return cooldown_data is not None
@staticmethod
async def trust_device(
redis: Redis,
user_id: int,
device_fingerprint: str,
client_info: ClientInfo,
trust_duration_days: int | None = None,
) -> None:
"""
信任设备
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
client_info: 客户端信息
trust_duration_days: 信任持续天数
"""
if not device_fingerprint:
return
# 使用配置中的默认值
if trust_duration_days is None:
trust_duration_days = settings.device_trust_duration_days
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
trust_data = {
"client_type": client_info.client_type,
"platform": client_info.platform or "unknown",
"trusted_at": utcnow().isoformat(),
}
# 设置信任期限
trust_duration_seconds = trust_duration_days * 24 * 3600
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
logger.info(
f"[Device Trust] Device trusted for user {user_id}: "
f"{client_info.client_type} on {client_info.platform} "
f"(fingerprint: {device_fingerprint[:8]}...)"
)
@staticmethod
async def trust_location(
redis: Redis,
user_id: int,
country_code: str,
trust_duration_days: int | None = None,
) -> None:
"""
信任位置
Args:
redis: Redis 连接
user_id: 用户 ID
country_code: 国家代码
trust_duration_days: 信任持续天数
"""
if not country_code:
return
# 使用配置中的默认值
if trust_duration_days is None:
trust_duration_days = settings.location_trust_duration_days
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
trust_data = {
"country_code": country_code,
"trusted_at": utcnow().isoformat(),
}
# 设置信任期限
trust_duration_seconds = trust_duration_days * 24 * 3600
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
logger.info(f"[Location Trust] Location trusted for user {user_id}: {country_code}")
@staticmethod
async def set_verification_cooldown(
redis: Redis,
user_id: int,
cooldown_seconds: int,
) -> None:
"""
设置验证冷却期
Args:
redis: Redis 连接
user_id: 用户 ID
cooldown_seconds: 冷却时间(秒)
"""
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
cooldown_data = {
"set_at": utcnow().isoformat(),
"expires_at": (utcnow() + timedelta(seconds=cooldown_seconds)).isoformat(),
}
await redis.setex(cooldown_key, cooldown_seconds, str(cooldown_data))
logger.info(f"[Verification Cooldown] Set cooldown for user {user_id}: {cooldown_seconds}s")
@staticmethod
async def should_require_verification(
redis: Redis,
user_id: int,
device_fingerprint: str | None,
country_code: str | None,
client_info: ClientInfo,
is_new_location: bool,
) -> tuple[bool, str]:
"""
判断是否需要验证
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
country_code: 国家代码
client_info: 客户端信息
is_new_location: 是否为新位置
Returns:
tuple[bool, str]: (是否需要验证, 原因)
"""
# 检查验证冷却期
if await DeviceTrustService.is_in_verification_cooldown(redis, user_id):
return False, "用户在验证冷却期内"
# 检查设备信任
if device_fingerprint and await DeviceTrustService.is_device_trusted(redis, user_id, device_fingerprint):
return False, "设备已受信任"
# 检查位置信任
if country_code and await DeviceTrustService.is_location_trusted(redis, user_id, country_code):
return False, "位置已受信任"
# 受信任的客户端类型降低验证要求
if client_info.is_trusted_client and not is_new_location:
return False, "受信任客户端且非新位置"
# 如果是新位置登录,需要验证
if is_new_location:
return True, "新位置登录需要验证"
# 默认不需要验证
return False, "常规登录无需验证"
@staticmethod
async def mark_verification_successful(
redis: Redis,
user_id: int,
device_fingerprint: str | None,
country_code: str | None,
client_info: ClientInfo,
) -> None:
"""
标记验证成功,更新信任信息
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
country_code: 国家代码
client_info: 客户端信息
"""
# 信任设备
if device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, device_fingerprint, client_info)
# 信任位置
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
# 设置验证冷却期
cooldown_seconds = (client_info.is_trusted_client and 3600) or 1800 # 受信任客户端1小时其他30分钟
await DeviceTrustService.set_verification_cooldown(redis, user_id, cooldown_seconds)
logger.info(f"[Device Trust] Verification successful for user {user_id}, trust updated")

View File

@@ -242,14 +242,16 @@ class EmailQueue:
if html_content: if html_content:
msg.attach(MIMEText(html_content, "html", "utf-8")) msg.attach(MIMEText(html_content, "html", "utf-8"))
# 发送邮件 # 发送邮件 - 使用线程池避免阻塞事件循环
def send_smtp_email():
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password: if self.smtp_username and self.smtp_password:
server.starttls() server.starttls()
server.login(self.smtp_username, self.smtp_password) server.login(self.smtp_username, self.smtp_password)
server.send_message(msg) server.send_message(msg)
await self._run_in_executor(send_smtp_email)
return True return True
except Exception as e: except Exception as e:

View File

@@ -52,7 +52,7 @@ class EmailService:
line-height: 1.6; line-height: 1.6;
}} }}
.header {{ .header {{
background: linear-gradient(135deg, #ff66aa, #ff9966); background: #ED8EA6;
color: white; color: white;
padding: 20px; padding: 20px;
text-align: center; text-align: center;
@@ -65,7 +65,7 @@ class EmailService:
}} }}
.code {{ .code {{
background: #fff; background: #fff;
border: 2px solid #ff66aa; border: 2px solid #ED8EA6;
border-radius: 8px; border-radius: 8px;
padding: 15px; padding: 15px;
text-align: center; text-align: center;

View File

@@ -141,7 +141,7 @@ class PasswordResetService:
line-height: 1.6; line-height: 1.6;
}} }}
.header {{ .header {{
background: linear-gradient(135deg, #ff6b6b, #ee5a24); background: #ED8EA6;
color: white; color: white;
padding: 20px; padding: 20px;
text-align: center; text-align: center;
@@ -154,7 +154,7 @@ class PasswordResetService:
}} }}
.code {{ .code {{
background: #fff; background: #fff;
border: 2px solid #ff6b6b; border: 2px solid #ED8EA6;
border-radius: 8px; border-radius: 8px;
padding: 15px; padding: 15px;
text-align: center; text-align: center;

View File

@@ -12,11 +12,13 @@ from typing import Literal
from app.config import settings from app.config import settings
from app.database.verification import EmailVerification, LoginSession from app.database.verification import EmailVerification, LoginSession
from app.log import logger from app.log import logger
from app.service.client_detection_service import ClientDetectionService, ClientInfo
from app.service.device_trust_service import DeviceTrustService
from app.service.email_queue import email_queue # 导入邮件队列 from app.service.email_queue import email_queue # 导入邮件队列
from app.utils import utcnow from app.utils import utcnow
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, exists, select from sqlmodel import exists, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -56,7 +58,7 @@ class EmailVerificationService:
line-height: 1.6; line-height: 1.6;
}} }}
.header {{ .header {{
background: linear-gradient(135deg, #ff66aa, #ff9966); background: #ED8EA6;
color: white; color: white;
padding: 20px; padding: 20px;
text-align: center; text-align: center;
@@ -69,7 +71,7 @@ class EmailVerificationService:
}} }}
.code {{ .code {{
background: #fff; background: #fff;
border: 2px solid #ff66aa; border: 2px solid #ED8EA6;
border-radius: 8px; border-radius: 8px;
padding: 15px; padding: 15px;
text-align: center; text-align: center;
@@ -201,7 +203,7 @@ This email was sent automatically, please do not reply.
existing_result = await db.exec( existing_result = await db.exec(
select(EmailVerification).where( select(EmailVerification).where(
EmailVerification.user_id == user_id, EmailVerification.user_id == user_id,
col(EmailVerification.is_used).is_(False), EmailVerification.is_used == False, # noqa: E712
EmailVerification.expires_at > utcnow(), EmailVerification.expires_at > utcnow(),
) )
) )
@@ -247,14 +249,37 @@ This email was sent automatically, please do not reply.
email: str, email: str,
ip_address: str | None = None, ip_address: str | None = None,
user_agent: str | None = None, user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
) -> bool: ) -> bool:
"""发送验证邮件""" """发送验证邮件(带智能检测)"""
try: try:
# 检查是否启用邮件验证功能 # 检查是否启用邮件验证功能
if not settings.enable_email_verification: if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}") logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
return True # 返回成功,但不执行验证流程 return True # 返回成功,但不执行验证流程
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
logger.info(
f"[Email Verification] Detected client for user {user_id}: "
f"{ClientDetectionService.format_client_display_name(client_info)}"
)
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=True, # 这里需要从调用方传入
)
if not needs_verification:
logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}")
return True
# 创建验证记录 # 创建验证记录
( (
_, _,
@@ -279,6 +304,107 @@ This email was sent automatically, please do not reply.
logger.error(f"[Email Verification] Exception during sending verification email: {e}") logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False return False
@staticmethod
async def send_smart_verification_email(
db: AsyncSession,
redis: Redis,
user_id: int,
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
is_new_location: bool = False,
) -> tuple[bool, str, ClientInfo | None]:
"""
智能邮件验证发送
Args:
db: 数据库会话
redis: Redis 连接
user_id: 用户 ID
username: 用户名
email: 邮箱地址
ip_address: IP 地址
user_agent: 用户代理
client_id: 客户端 ID
country_code: 国家代码
is_new_location: 是否为新位置登录
Returns:
tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息)
"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}")
return True, "邮件验证功能已禁用", None
# 检查是否启用智能验证
if not settings.enable_smart_verification:
logger.debug(
f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}"
)
# 回退到传统验证逻辑
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
success = await EmailVerificationService.send_verification_email_via_queue(
email, code, username, user_id
)
return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
client_display_name = ClientDetectionService.format_client_display_name(client_info)
logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}")
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=is_new_location,
)
if not needs_verification:
logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}")
# 即使不需要验证,也要更新设备信任信息
if client_info.device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info)
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
return True, f"跳过验证: {reason}", client_info
# 创建验证记录
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
_ = verification # 避免未使用变量警告
# 使用邮件队列发送验证邮件
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success:
logger.info(
f"[Smart Verification] Successfully sent verification email to {email} "
f"for user {username} using {client_display_name}"
)
return True, "验证邮件已发送", client_info
else:
logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})")
return False, "验证邮件发送失败", client_info
except Exception as e:
logger.error(f"[Smart Verification] Exception during smart verification: {e}")
return False, f"验证过程中发生错误: {e!s}", None
@staticmethod @staticmethod
async def verify_email_code( async def verify_email_code(
db: AsyncSession, db: AsyncSession,
@@ -286,8 +412,11 @@ This email was sent automatically, please do not reply.
user_id: int, user_id: int,
code: str, code: str,
ip_address: str | None = None, ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""验证邮箱验证码""" """验证邮箱验证码(带智能信任更新)"""
try: try:
# 检查是否启用邮件验证功能 # 检查是否启用邮件验证功能
if not settings.enable_email_verification: if not settings.enable_email_verification:
@@ -305,7 +434,7 @@ This email was sent automatically, please do not reply.
EmailVerification.id == int(verification_id), EmailVerification.id == int(verification_id),
EmailVerification.user_id == user_id, EmailVerification.user_id == user_id,
EmailVerification.verification_code == code, EmailVerification.verification_code == code,
col(EmailVerification.is_used).is_(False), EmailVerification.is_used == False, # noqa: E712
EmailVerification.expires_at > utcnow(), EmailVerification.expires_at > utcnow(),
) )
) )
@@ -323,6 +452,16 @@ This email was sent automatically, please do not reply.
# 删除 Redis 记录 # 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}") await redis.delete(f"email_verification:{user_id}:{code}")
# 检测客户端信息并更新信任状态
client_info = ClientDetectionService.detect_client(user_agent, client_id)
await DeviceTrustService.mark_verification_successful(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
)
logger.info(f"[Email Verification] User {user_id} verification code verified successfully") logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功" return True, "验证成功"
@@ -342,6 +481,8 @@ This email was sent automatically, please do not reply.
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""重新发送验证码""" """重新发送验证码"""
try: try:
# 避免未使用参数警告
_ = user_agent
# 检查是否启用邮件验证功能 # 检查是否启用邮件验证功能
if not settings.enable_email_verification: if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}") logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
@@ -465,7 +606,7 @@ class LoginSessionService:
result = await db.exec( result = await db.exec(
select(LoginSession).where( select(LoginSession).where(
LoginSession.user_id == user_id, LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False), LoginSession.is_verified == False, # noqa: E712
LoginSession.expires_at > utcnow(), LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id, LoginSession.token_id == token_id,
) )
@@ -497,7 +638,7 @@ class LoginSessionService:
await db.exec( await db.exec(
select(exists()).where( select(exists()).where(
LoginSession.user_id == user_id, LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False), LoginSession.is_verified == False, # noqa: E712
LoginSession.expires_at > utcnow(), LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id, LoginSession.token_id == token_id,
) )