From 99d6af1c1fc17e57992c648fc9cc139a3298b9bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= Date: Tue, 23 Sep 2025 22:13:15 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=82=AE=E4=BB=B6=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E9=83=A8=E5=88=86=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 6 +- app/auth.py | 42 +++- app/config.py | 25 +++ app/router/auth.py | 34 ++- app/service/client_detection_service.py | 230 +++++++++++++++++++ app/service/device_trust_service.py | 283 ++++++++++++++++++++++++ app/service/email_queue.py | 14 +- app/service/email_service.py | 4 +- app/service/password_reset_service.py | 4 +- app/service/verification_service.py | 159 ++++++++++++- 10 files changed, 770 insertions(+), 31 deletions(-) create mode 100644 app/service/client_detection_service.py create mode 100644 app/service/device_trust_service.py diff --git a/.gitignore b/.gitignore index e491cc1..364138f 100644 --- a/.gitignore +++ b/.gitignore @@ -222,4 +222,8 @@ newrelic.ini logs/ osu-server-spectator-master/* spectator-server/ -.github/copilot-instructions.md \ No newline at end of file +.github/copilot-instructions.md +osu-web-master/* +osu-web-master/.env.dusk.local.example +osu-web-master/.env.example +osu-web-master/.env.testing.example diff --git a/app/auth.py b/app/auth.py index 6f773a7..4435fdf 100644 --- a/app/auth.py +++ b/app/auth.py @@ -217,15 +217,41 @@ async def store_token( access_token: str, refresh_token: str, expires_in: int, + allow_multiple_devices: bool = True, ) -> OAuthToken: - """存储令牌到数据库""" + """存储令牌到数据库(支持多设备)""" expires_at = utcnow() + timedelta(seconds=expires_in) - # 删除用户的旧令牌 - statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id) - old_tokens = (await db.exec(statement)).all() - for token in old_tokens: - await db.delete(token) + if not allow_multiple_devices: + # 旧的行为:删除用户的旧令牌(单设备模式) + statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id) + old_tokens = (await db.exec(statement)).all() + for token in old_tokens: + await db.delete(token) + else: + # 新的行为:只删除过期的令牌,保留有效的令牌(多设备模式) + statement = select(OAuthToken).where( + OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at <= utcnow() + ) + expired_tokens = (await db.exec(statement)).all() + for token in expired_tokens: + await db.delete(token) + + # 限制每个用户每个客户端的最大令牌数量(防止无限增长) + max_tokens_per_client = settings.max_tokens_per_client + statement = ( + select(OAuthToken) + .where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at > utcnow()) + .order_by(OAuthToken.created_at.desc()) + ) + + active_tokens = (await db.exec(statement)).all() + if len(active_tokens) >= max_tokens_per_client: + # 删除最旧的令牌 + tokens_to_delete = active_tokens[max_tokens_per_client - 1 :] + for token in tokens_to_delete: + await db.delete(token) + logger.info(f"[Auth] Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}") # 检查是否有重复的 access_token duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first() @@ -244,6 +270,10 @@ async def store_token( db.add(token_record) await db.commit() await db.refresh(token_record) + + logger.info( + f"[Auth] Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})" + ) return token_record diff --git a/app/config.py b/app/config.py index 335a307..1805976 100644 --- a/app/config.py +++ b/app/config.py @@ -309,6 +309,31 @@ STORAGE_SETTINGS='{ Field(default=False, description="是否启用邮件验证功能"), "验证服务设置", ] + enable_smart_verification: Annotated[ + bool, + Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"), + "验证服务设置", + ] + enable_multi_device_login: Annotated[ + bool, + Field(default=True, description="是否允许多设备同时登录"), + "验证服务设置", + ] + max_tokens_per_client: Annotated[ + int, + Field(default=10, description="每个用户每个客户端的最大令牌数量"), + "验证服务设置", + ] + device_trust_duration_days: Annotated[ + int, + Field(default=30, description="设备信任持续天数"), + "验证服务设置", + ] + location_trust_duration_days: Annotated[ + int, + Field(default=90, description="位置信任持续天数"), + "验证服务设置", + ] smtp_server: Annotated[ str, Field(default="localhost", description="SMTP 服务器地址"), diff --git a/app/router/auth.py b/app/router/auth.py index c7b84e7..7254b2c 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -214,6 +214,9 @@ async def oauth_token( ): scopes = scope.split(" ") + # 打印请求头 + # logger.info(f"Request headers: {request.headers}") + client = ( await db.exec( select(OAuthClient).where( @@ -303,6 +306,7 @@ async def oauth_token( access_token, refresh_token_str, settings.access_token_expire_minutes * 60, + allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) token_id = token.id @@ -333,24 +337,41 @@ async def oauth_token( await db.refresh(user) session_verification_method = "mail" - # 发送邮件验证码 - verification_sent = await EmailVerificationService.send_verification_email( - db, redis, user_id, user.username, user.email, ip_address, user_agent + # 使用智能验证发送邮件 + ( + verification_sent, + verification_message, + client_info, + ) = await EmailVerificationService.send_smart_verification_email( + db, + redis, + user_id, + user.username, + user.email, + ip_address, + user_agent, + client_id, + country_code, + is_new_location, ) # 记录需要二次验证的登录尝试 + client_display_name = client_info.client_type if client_info else "unknown" await LoginLogService.record_login( db=db, user_id=user_id, request=request, login_success=True, login_method="password_pending_verification", - notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}", + notes=f"智能验证: {verification_message} - 客户端: {client_display_name}, " + f"IP: {ip_address}, 国家: {country_code}", ) if not verification_sent: # 邮件发送失败,记录错误 - logger.error(f"[Auth] Failed to send email verification code for user {user_id}") + logger.error(f"[Auth] Smart verification failed for user {user_id}: {verification_message}") + else: + logger.info(f"[Auth] Smart verification result for user {user_id}: {verification_message}") elif is_new_location: # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 await LoginSessionService.mark_session_verified(db, redis, user_id, token_id) @@ -428,6 +449,7 @@ async def oauth_token( access_token, new_refresh_token, settings.access_token_expire_minutes * 60, + allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) return TokenResponse( access_token=access_token, @@ -492,6 +514,7 @@ async def oauth_token( access_token, refresh_token_str, settings.access_token_expire_minutes * 60, + allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) # 打印jwt @@ -538,6 +561,7 @@ async def oauth_token( access_token, refresh_token_str, settings.access_token_expire_minutes * 60, + allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) return TokenResponse( diff --git a/app/service/client_detection_service.py b/app/service/client_detection_service.py new file mode 100644 index 0000000..5fc98c4 --- /dev/null +++ b/app/service/client_detection_service.py @@ -0,0 +1,230 @@ +""" +客户端检测服务 +用于识别不同类型的 osu! 客户端和设备 +""" + +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +import re +from typing import ClassVar, Literal + +from app.log import logger + + +@dataclass +class ClientInfo: + """客户端信息""" + + client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] + platform: str | None = None + version: str | None = None + device_fingerprint: str | None = None + is_trusted_client: bool = False + + +class ClientDetectionService: + """客户端检测服务""" + + # osu! 客户端的 User-Agent 模式 + OSU_CLIENT_PATTERNS: ClassVar[dict[str, list[str]]] = { + "osu_stable": [ + r"osu!/(\d+(?:\.\d+)*)", # osu!/20241001 + r"osu!", # 简单匹配 + ], + "osu_lazer": [ + r"osu-lazer/(\d+(?:\.\d+)*)", # osu-lazer/2024.1009.0 + r"osu!lazer/(\d+(?:\.\d+)*)", # osu!lazer/2024.1009.0 + ], + "osu_web": [ + r"Mozilla.*osu\.ppy\.sh", # 网页客户端 + ], + "mobile": [ + r"osu!.*mobile", + r"osu.*Mobile", + r"Mobile.*osu", + ], + } + + # 受信任的客户端类型(不需要频繁验证) + TRUSTED_CLIENT_TYPES: ClassVar[set[str]] = {"osu_stable", "osu_lazer"} + + @staticmethod + def detect_client(user_agent: str | None, client_id: int | None = None) -> ClientInfo: + """ + 检测客户端类型和信息 + + Args: + user_agent: 用户代理字符串 + client_id: OAuth 客户端 ID + + Returns: + ClientInfo: 客户端信息 + """ + from app.config import settings # 导入在函数内部避免循环导入 + + if not user_agent: + return ClientInfo(client_type="unknown") + + # 优先通过 client_id 判断客户端类型 + if client_id is not None: + if client_id == settings.osu_client_id: + # osu! stable 客户端 + return ClientInfo( + client_type="osu_stable", + platform=ClientDetectionService._extract_platform(user_agent), + device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), + is_trusted_client=True, + ) + elif client_id == settings.osu_web_client_id: + # 检查 User-Agent 是否表明这是 Lazer 客户端 + if user_agent and user_agent.strip() == "osu!": + # Lazer 客户端使用 web client_id 但发送简单的 "osu!" User-Agent + return ClientInfo( + client_type="osu_lazer", + platform=ClientDetectionService._extract_platform(user_agent), + device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), + is_trusted_client=True, + ) + else: + # 真正的 web 客户端 + return ClientInfo( + client_type="osu_web", + platform=ClientDetectionService._extract_platform(user_agent), + device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), + is_trusted_client=False, + ) + + # 回退到基于 User-Agent 的检测 + for client_type_str, patterns in ClientDetectionService.OSU_CLIENT_PATTERNS.items(): + for pattern in patterns: + match = re.search(pattern, user_agent, re.IGNORECASE) + if match: + version = match.group(1) if match.groups() else None + platform = ClientDetectionService._extract_platform(user_agent) + + # 确保 client_type 是正确的 Literal 类型 + client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] = client_type_str # type: ignore + + return ClientInfo( + client_type=client_type, + platform=platform, + version=version, + device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), + is_trusted_client=client_type in ClientDetectionService.TRUSTED_CLIENT_TYPES, + ) + + # 检测常见浏览器 + if any(browser in user_agent.lower() for browser in ["chrome", "firefox", "safari", "edge"]): + return ClientInfo( + client_type="osu_web", + platform=ClientDetectionService._extract_platform(user_agent), + device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), + is_trusted_client=False, + ) + + return ClientInfo( + client_type="unknown", + device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), + is_trusted_client=False, + ) + + @staticmethod + def _extract_platform(user_agent: str) -> str | None: + """从 User-Agent 中提取平台信息""" + platforms = { + "windows": ["windows", "win32", "win64"], + "macos": ["macintosh", "mac os", "darwin"], + "linux": ["linux", "ubuntu", "debian"], + "android": ["android"], + "ios": ["iphone", "ipad", "ios"], + } + + user_agent_lower = user_agent.lower() + for platform, keywords in platforms.items(): + if any(keyword in user_agent_lower for keyword in keywords): + return platform + + return None + + @staticmethod + def _generate_device_fingerprint(user_agent: str) -> str: + """生成设备指纹""" + # 使用 User-Agent 的哈希值作为简单的设备指纹 + # 在实际应用中可以结合更多信息(IP、屏幕分辨率等) + return hashlib.sha256(user_agent.encode()).hexdigest()[:16] + + @staticmethod + def should_skip_email_verification( + client_info: ClientInfo, + is_new_location: bool, + user_id: int, + ) -> bool: + """ + 判断是否应该跳过邮件验证 + + Args: + client_info: 客户端信息 + is_new_location: 是否为新位置登录 + user_id: 用户 ID + + Returns: + bool: 是否应该跳过邮件验证 + """ + # 受信任的客户端类型可以减少验证频率 + if client_info.is_trusted_client: + logger.info( + f"[Client Detection] Trusted client {client_info.client_type} for user {user_id}, " + f"reducing verification requirements" + ) + return True + + # 如果不是新位置,跳过验证 + if not is_new_location: + return True + + return False + + @staticmethod + def get_verification_cooldown(client_info: ClientInfo) -> int: + """ + 获取验证冷却时间(秒) + + Args: + client_info: 客户端信息 + + Returns: + int: 冷却时间(秒) + """ + # 受信任的客户端有更长的冷却时间 + if client_info.is_trusted_client: + return 3600 # 1小时 + + # 网页客户端较短的冷却时间 + if client_info.client_type == "osu_web": + return 1800 # 30分钟 + + # 未知客户端最短冷却时间 + return 900 # 15分钟 + + @staticmethod + def format_client_display_name(client_info: ClientInfo) -> str: + """格式化客户端显示名称""" + display_names = { + "osu_stable": "osu! (stable)", + "osu_lazer": "osu!(lazer)", + "osu_web": "osu! web", + "mobile": "osu! mobile", + "unknown": "Unknown client", + } + + base_name = display_names.get(client_info.client_type, "Unknown client") + + if client_info.version: + base_name += f" v{client_info.version}" + + if client_info.platform: + base_name += f" ({client_info.platform})" + + return base_name diff --git a/app/service/device_trust_service.py b/app/service/device_trust_service.py new file mode 100644 index 0000000..1b4e623 --- /dev/null +++ b/app/service/device_trust_service.py @@ -0,0 +1,283 @@ +""" +设备信任服务 +管理用户的受信任设备,减少频繁验证 +""" + +from __future__ import annotations + +from datetime import timedelta + +from app.config import settings +from app.log import logger +from app.service.client_detection_service import ClientInfo +from app.utils import utcnow + +from redis.asyncio import Redis + + +class DeviceTrustService: + """设备信任服务""" + + @staticmethod + def _get_device_trust_key(user_id: int, device_fingerprint: str) -> str: + """获取设备信任的 Redis 键""" + return f"device_trust:{user_id}:{device_fingerprint}" + + @staticmethod + def _get_location_trust_key(user_id: int, country_code: str) -> str: + """获取位置信任的 Redis 键""" + return f"location_trust:{user_id}:{country_code}" + + @staticmethod + def _get_verification_cooldown_key(user_id: int) -> str: + """获取验证冷却的 Redis 键""" + return f"verification_cooldown:{user_id}" + + @staticmethod + async def is_device_trusted( + redis: Redis, + user_id: int, + device_fingerprint: str, + ) -> bool: + """ + 检查设备是否受信任 + + Args: + redis: Redis 连接 + user_id: 用户 ID + device_fingerprint: 设备指纹 + + Returns: + bool: 设备是否受信任 + """ + if not device_fingerprint: + return False + + trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint) + trust_data = await redis.get(trust_key) + + return trust_data is not None + + @staticmethod + async def is_location_trusted( + redis: Redis, + user_id: int, + country_code: str | None, + ) -> bool: + """ + 检查位置是否受信任 + + Args: + redis: Redis 连接 + user_id: 用户 ID + country_code: 国家代码 + + Returns: + bool: 位置是否受信任 + """ + if not country_code: + return False + + trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code) + trust_data = await redis.get(trust_key) + + return trust_data is not None + + @staticmethod + async def is_in_verification_cooldown( + redis: Redis, + user_id: int, + ) -> bool: + """ + 检查用户是否在验证冷却期内 + + Args: + redis: Redis 连接 + user_id: 用户 ID + + Returns: + bool: 是否在冷却期内 + """ + cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id) + cooldown_data = await redis.get(cooldown_key) + + return cooldown_data is not None + + @staticmethod + async def trust_device( + redis: Redis, + user_id: int, + device_fingerprint: str, + client_info: ClientInfo, + trust_duration_days: int | None = None, + ) -> None: + """ + 信任设备 + + Args: + redis: Redis 连接 + user_id: 用户 ID + device_fingerprint: 设备指纹 + client_info: 客户端信息 + trust_duration_days: 信任持续天数 + """ + if not device_fingerprint: + return + + # 使用配置中的默认值 + if trust_duration_days is None: + trust_duration_days = settings.device_trust_duration_days + + trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint) + trust_data = { + "client_type": client_info.client_type, + "platform": client_info.platform or "unknown", + "trusted_at": utcnow().isoformat(), + } + + # 设置信任期限 + trust_duration_seconds = trust_duration_days * 24 * 3600 + await redis.setex(trust_key, trust_duration_seconds, str(trust_data)) + + logger.info( + f"[Device Trust] Device trusted for user {user_id}: " + f"{client_info.client_type} on {client_info.platform} " + f"(fingerprint: {device_fingerprint[:8]}...)" + ) + + @staticmethod + async def trust_location( + redis: Redis, + user_id: int, + country_code: str, + trust_duration_days: int | None = None, + ) -> None: + """ + 信任位置 + + Args: + redis: Redis 连接 + user_id: 用户 ID + country_code: 国家代码 + trust_duration_days: 信任持续天数 + """ + if not country_code: + return + + # 使用配置中的默认值 + if trust_duration_days is None: + trust_duration_days = settings.location_trust_duration_days + + trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code) + trust_data = { + "country_code": country_code, + "trusted_at": utcnow().isoformat(), + } + + # 设置信任期限 + trust_duration_seconds = trust_duration_days * 24 * 3600 + await redis.setex(trust_key, trust_duration_seconds, str(trust_data)) + + logger.info(f"[Location Trust] Location trusted for user {user_id}: {country_code}") + + @staticmethod + async def set_verification_cooldown( + redis: Redis, + user_id: int, + cooldown_seconds: int, + ) -> None: + """ + 设置验证冷却期 + + Args: + redis: Redis 连接 + user_id: 用户 ID + cooldown_seconds: 冷却时间(秒) + """ + cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id) + cooldown_data = { + "set_at": utcnow().isoformat(), + "expires_at": (utcnow() + timedelta(seconds=cooldown_seconds)).isoformat(), + } + + await redis.setex(cooldown_key, cooldown_seconds, str(cooldown_data)) + + logger.info(f"[Verification Cooldown] Set cooldown for user {user_id}: {cooldown_seconds}s") + + @staticmethod + async def should_require_verification( + redis: Redis, + user_id: int, + device_fingerprint: str | None, + country_code: str | None, + client_info: ClientInfo, + is_new_location: bool, + ) -> tuple[bool, str]: + """ + 判断是否需要验证 + + Args: + redis: Redis 连接 + user_id: 用户 ID + device_fingerprint: 设备指纹 + country_code: 国家代码 + client_info: 客户端信息 + is_new_location: 是否为新位置 + + Returns: + tuple[bool, str]: (是否需要验证, 原因) + """ + # 检查验证冷却期 + if await DeviceTrustService.is_in_verification_cooldown(redis, user_id): + return False, "用户在验证冷却期内" + + # 检查设备信任 + if device_fingerprint and await DeviceTrustService.is_device_trusted(redis, user_id, device_fingerprint): + return False, "设备已受信任" + + # 检查位置信任 + if country_code and await DeviceTrustService.is_location_trusted(redis, user_id, country_code): + return False, "位置已受信任" + + # 受信任的客户端类型降低验证要求 + if client_info.is_trusted_client and not is_new_location: + return False, "受信任客户端且非新位置" + + # 如果是新位置登录,需要验证 + if is_new_location: + return True, "新位置登录需要验证" + + # 默认不需要验证 + return False, "常规登录无需验证" + + @staticmethod + async def mark_verification_successful( + redis: Redis, + user_id: int, + device_fingerprint: str | None, + country_code: str | None, + client_info: ClientInfo, + ) -> None: + """ + 标记验证成功,更新信任信息 + + Args: + redis: Redis 连接 + user_id: 用户 ID + device_fingerprint: 设备指纹 + country_code: 国家代码 + client_info: 客户端信息 + """ + # 信任设备 + if device_fingerprint: + await DeviceTrustService.trust_device(redis, user_id, device_fingerprint, client_info) + + # 信任位置 + if country_code: + await DeviceTrustService.trust_location(redis, user_id, country_code) + + # 设置验证冷却期 + cooldown_seconds = (client_info.is_trusted_client and 3600) or 1800 # 受信任客户端1小时,其他30分钟 + await DeviceTrustService.set_verification_cooldown(redis, user_id, cooldown_seconds) + + logger.info(f"[Device Trust] Verification successful for user {user_id}, trust updated") diff --git a/app/service/email_queue.py b/app/service/email_queue.py index 3088574..9c3ec62 100644 --- a/app/service/email_queue.py +++ b/app/service/email_queue.py @@ -242,13 +242,15 @@ class EmailQueue: if html_content: msg.attach(MIMEText(html_content, "html", "utf-8")) - # 发送邮件 - with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: - if self.smtp_username and self.smtp_password: - server.starttls() - server.login(self.smtp_username, self.smtp_password) + # 发送邮件 - 使用线程池避免阻塞事件循环 + def send_smtp_email(): + with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: + if self.smtp_username and self.smtp_password: + server.starttls() + server.login(self.smtp_username, self.smtp_password) + server.send_message(msg) - server.send_message(msg) + await self._run_in_executor(send_smtp_email) return True diff --git a/app/service/email_service.py b/app/service/email_service.py index 32f7124..73c4b7c 100644 --- a/app/service/email_service.py +++ b/app/service/email_service.py @@ -52,7 +52,7 @@ class EmailService: line-height: 1.6; }} .header {{ - background: linear-gradient(135deg, #ff66aa, #ff9966); + background: #ED8EA6; color: white; padding: 20px; text-align: center; @@ -65,7 +65,7 @@ class EmailService: }} .code {{ background: #fff; - border: 2px solid #ff66aa; + border: 2px solid #ED8EA6; border-radius: 8px; padding: 15px; text-align: center; diff --git a/app/service/password_reset_service.py b/app/service/password_reset_service.py index de301de..b822329 100644 --- a/app/service/password_reset_service.py +++ b/app/service/password_reset_service.py @@ -141,7 +141,7 @@ class PasswordResetService: line-height: 1.6; }} .header {{ - background: linear-gradient(135deg, #ff6b6b, #ee5a24); + background: #ED8EA6; color: white; padding: 20px; text-align: center; @@ -154,7 +154,7 @@ class PasswordResetService: }} .code {{ background: #fff; - border: 2px solid #ff6b6b; + border: 2px solid #ED8EA6; border-radius: 8px; padding: 15px; text-align: center; diff --git a/app/service/verification_service.py b/app/service/verification_service.py index 5aec20b..1a1dfd2 100644 --- a/app/service/verification_service.py +++ b/app/service/verification_service.py @@ -12,11 +12,13 @@ from typing import Literal from app.config import settings from app.database.verification import EmailVerification, LoginSession from app.log import logger +from app.service.client_detection_service import ClientDetectionService, ClientInfo +from app.service.device_trust_service import DeviceTrustService from app.service.email_queue import email_queue # 导入邮件队列 from app.utils import utcnow from redis.asyncio import Redis -from sqlmodel import col, exists, select +from sqlmodel import exists, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -56,7 +58,7 @@ class EmailVerificationService: line-height: 1.6; }} .header {{ - background: linear-gradient(135deg, #ff66aa, #ff9966); + background: #ED8EA6; color: white; padding: 20px; text-align: center; @@ -69,7 +71,7 @@ class EmailVerificationService: }} .code {{ background: #fff; - border: 2px solid #ff66aa; + border: 2px solid #ED8EA6; border-radius: 8px; padding: 15px; text-align: center; @@ -201,7 +203,7 @@ This email was sent automatically, please do not reply. existing_result = await db.exec( select(EmailVerification).where( EmailVerification.user_id == user_id, - col(EmailVerification.is_used).is_(False), + EmailVerification.is_used == False, # noqa: E712 EmailVerification.expires_at > utcnow(), ) ) @@ -247,14 +249,37 @@ This email was sent automatically, please do not reply. email: str, ip_address: str | None = None, user_agent: str | None = None, + client_id: int | None = None, + country_code: str | None = None, ) -> bool: - """发送验证邮件""" + """发送验证邮件(带智能检测)""" try: # 检查是否启用邮件验证功能 if not settings.enable_email_verification: logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}") return True # 返回成功,但不执行验证流程 + # 检测客户端信息 + client_info = ClientDetectionService.detect_client(user_agent, client_id) + logger.info( + f"[Email Verification] Detected client for user {user_id}: " + f"{ClientDetectionService.format_client_display_name(client_info)}" + ) + + # 检查是否需要验证 + needs_verification, reason = await DeviceTrustService.should_require_verification( + redis=redis, + user_id=user_id, + device_fingerprint=client_info.device_fingerprint, + country_code=country_code, + client_info=client_info, + is_new_location=True, # 这里需要从调用方传入 + ) + + if not needs_verification: + logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}") + return True + # 创建验证记录 ( _, @@ -279,6 +304,107 @@ This email was sent automatically, please do not reply. logger.error(f"[Email Verification] Exception during sending verification email: {e}") return False + @staticmethod + async def send_smart_verification_email( + db: AsyncSession, + redis: Redis, + user_id: int, + username: str, + email: str, + ip_address: str | None = None, + user_agent: str | None = None, + client_id: int | None = None, + country_code: str | None = None, + is_new_location: bool = False, + ) -> tuple[bool, str, ClientInfo | None]: + """ + 智能邮件验证发送 + + Args: + db: 数据库会话 + redis: Redis 连接 + user_id: 用户 ID + username: 用户名 + email: 邮箱地址 + ip_address: IP 地址 + user_agent: 用户代理 + client_id: 客户端 ID + country_code: 国家代码 + is_new_location: 是否为新位置登录 + + Returns: + tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息) + """ + try: + # 检查是否启用邮件验证功能 + if not settings.enable_email_verification: + logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}") + return True, "邮件验证功能已禁用", None + + # 检查是否启用智能验证 + if not settings.enable_smart_verification: + logger.debug( + f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}" + ) + # 回退到传统验证逻辑 + verification, code = await EmailVerificationService.create_verification_record( + db, redis, user_id, email, ip_address, user_agent + ) + success = await EmailVerificationService.send_verification_email_via_queue( + email, code, username, user_id + ) + return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None + + # 检测客户端信息 + client_info = ClientDetectionService.detect_client(user_agent, client_id) + client_display_name = ClientDetectionService.format_client_display_name(client_info) + + logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}") + + # 检查是否需要验证 + needs_verification, reason = await DeviceTrustService.should_require_verification( + redis=redis, + user_id=user_id, + device_fingerprint=client_info.device_fingerprint, + country_code=country_code, + client_info=client_info, + is_new_location=is_new_location, + ) + + if not needs_verification: + logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}") + + # 即使不需要验证,也要更新设备信任信息 + if client_info.device_fingerprint: + await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info) + if country_code: + await DeviceTrustService.trust_location(redis, user_id, country_code) + + return True, f"跳过验证: {reason}", client_info + + # 创建验证记录 + verification, code = await EmailVerificationService.create_verification_record( + db, redis, user_id, email, ip_address, user_agent + ) + _ = verification # 避免未使用变量警告 + + # 使用邮件队列发送验证邮件 + success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id) + + if success: + logger.info( + f"[Smart Verification] Successfully sent verification email to {email} " + f"for user {username} using {client_display_name}" + ) + return True, "验证邮件已发送", client_info + else: + logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})") + return False, "验证邮件发送失败", client_info + + except Exception as e: + logger.error(f"[Smart Verification] Exception during smart verification: {e}") + return False, f"验证过程中发生错误: {e!s}", None + @staticmethod async def verify_email_code( db: AsyncSession, @@ -286,8 +412,11 @@ This email was sent automatically, please do not reply. user_id: int, code: str, ip_address: str | None = None, + user_agent: str | None = None, + client_id: int | None = None, + country_code: str | None = None, ) -> tuple[bool, str]: - """验证邮箱验证码""" + """验证邮箱验证码(带智能信任更新)""" try: # 检查是否启用邮件验证功能 if not settings.enable_email_verification: @@ -305,7 +434,7 @@ This email was sent automatically, please do not reply. EmailVerification.id == int(verification_id), EmailVerification.user_id == user_id, EmailVerification.verification_code == code, - col(EmailVerification.is_used).is_(False), + EmailVerification.is_used == False, # noqa: E712 EmailVerification.expires_at > utcnow(), ) ) @@ -323,6 +452,16 @@ This email was sent automatically, please do not reply. # 删除 Redis 记录 await redis.delete(f"email_verification:{user_id}:{code}") + # 检测客户端信息并更新信任状态 + client_info = ClientDetectionService.detect_client(user_agent, client_id) + await DeviceTrustService.mark_verification_successful( + redis=redis, + user_id=user_id, + device_fingerprint=client_info.device_fingerprint, + country_code=country_code, + client_info=client_info, + ) + logger.info(f"[Email Verification] User {user_id} verification code verified successfully") return True, "验证成功" @@ -342,6 +481,8 @@ This email was sent automatically, please do not reply. ) -> tuple[bool, str]: """重新发送验证码""" try: + # 避免未使用参数警告 + _ = user_agent # 检查是否启用邮件验证功能 if not settings.enable_email_verification: logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}") @@ -465,7 +606,7 @@ class LoginSessionService: result = await db.exec( select(LoginSession).where( LoginSession.user_id == user_id, - col(LoginSession.is_verified).is_(False), + LoginSession.is_verified == False, # noqa: E712 LoginSession.expires_at > utcnow(), LoginSession.token_id == token_id, ) @@ -497,7 +638,7 @@ class LoginSessionService: await db.exec( select(exists()).where( LoginSession.user_id == user_id, - col(LoginSession.is_verified).is_(False), + LoginSession.is_verified == False, # noqa: E712 LoginSession.expires_at > utcnow(), LoginSession.token_id == token_id, )