feat(auth): support trusted device (#52)
New API to maintain sessions and devices:
- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}
Auth:
web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -10,15 +10,15 @@ import string
|
||||
from typing import Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.database.auth import OAuthToken
|
||||
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
|
||||
from app.log import logger
|
||||
from app.service.client_detection_service import ClientDetectionService, ClientInfo
|
||||
from app.service.device_trust_service import DeviceTrustService
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.models.model import UserAgentInfo
|
||||
from app.service.email_queue import email_queue
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import exists, select
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -248,11 +248,9 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
client_id: int | None = None,
|
||||
country_code: str | None = None,
|
||||
user_agent: UserAgentInfo | None = None,
|
||||
) -> bool:
|
||||
"""发送验证邮件(带智能检测)"""
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
@@ -260,32 +258,14 @@ This email was sent automatically, please do not reply.
|
||||
return True # 返回成功,但不执行验证流程
|
||||
|
||||
# 检测客户端信息
|
||||
client_info = ClientDetectionService.detect_client(user_agent, client_id)
|
||||
logger.info(
|
||||
f"[Email Verification] Detected client for user {user_id}: "
|
||||
f"{ClientDetectionService.format_client_display_name(client_info)}"
|
||||
)
|
||||
|
||||
# 检查是否需要验证
|
||||
needs_verification, reason = await DeviceTrustService.should_require_verification(
|
||||
redis=redis,
|
||||
user_id=user_id,
|
||||
device_fingerprint=client_info.device_fingerprint,
|
||||
country_code=country_code,
|
||||
client_info=client_info,
|
||||
is_new_location=True, # 这里需要从调用方传入
|
||||
)
|
||||
|
||||
if not needs_verification:
|
||||
logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}")
|
||||
return True
|
||||
logger.info(f"[Email Verification] Detected client for user {user_id}: {user_agent}")
|
||||
|
||||
# 创建验证记录
|
||||
(
|
||||
_,
|
||||
code,
|
||||
) = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
db, redis, user_id, email, ip_address, user_agent.raw_ua if user_agent else None
|
||||
)
|
||||
|
||||
# 使用邮件队列发送验证邮件
|
||||
@@ -304,107 +284,6 @@ This email was sent automatically, please do not reply.
|
||||
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def send_smart_verification_email(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
client_id: int | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
) -> tuple[bool, str, ClientInfo | None]:
|
||||
"""
|
||||
智能邮件验证发送
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
username: 用户名
|
||||
email: 邮箱地址
|
||||
ip_address: IP 地址
|
||||
user_agent: 用户代理
|
||||
client_id: 客户端 ID
|
||||
country_code: 国家代码
|
||||
is_new_location: 是否为新位置登录
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息)
|
||||
"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}")
|
||||
return True, "邮件验证功能已禁用", None
|
||||
|
||||
# 检查是否启用智能验证
|
||||
if not settings.enable_smart_verification:
|
||||
logger.debug(
|
||||
f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}"
|
||||
)
|
||||
# 回退到传统验证逻辑
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
success = await EmailVerificationService.send_verification_email_via_queue(
|
||||
email, code, username, user_id
|
||||
)
|
||||
return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None
|
||||
|
||||
# 检测客户端信息
|
||||
client_info = ClientDetectionService.detect_client(user_agent, client_id)
|
||||
client_display_name = ClientDetectionService.format_client_display_name(client_info)
|
||||
|
||||
logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}")
|
||||
|
||||
# 检查是否需要验证
|
||||
needs_verification, reason = await DeviceTrustService.should_require_verification(
|
||||
redis=redis,
|
||||
user_id=user_id,
|
||||
device_fingerprint=client_info.device_fingerprint,
|
||||
country_code=country_code,
|
||||
client_info=client_info,
|
||||
is_new_location=is_new_location,
|
||||
)
|
||||
|
||||
if not needs_verification:
|
||||
logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}")
|
||||
|
||||
# 即使不需要验证,也要更新设备信任信息
|
||||
if client_info.device_fingerprint:
|
||||
await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info)
|
||||
if country_code:
|
||||
await DeviceTrustService.trust_location(redis, user_id, country_code)
|
||||
|
||||
return True, f"跳过验证: {reason}", client_info
|
||||
|
||||
# 创建验证记录
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
_ = verification # 避免未使用变量警告
|
||||
|
||||
# 使用邮件队列发送验证邮件
|
||||
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
f"[Smart Verification] Successfully sent verification email to {email} "
|
||||
f"for user {username} using {client_display_name}"
|
||||
)
|
||||
return True, "验证邮件已发送", client_info
|
||||
else:
|
||||
logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})")
|
||||
return False, "验证邮件发送失败", client_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Smart Verification] Exception during smart verification: {e}")
|
||||
return False, f"验证过程中发生错误: {e!s}", None
|
||||
|
||||
@staticmethod
|
||||
async def verify_email_code(
|
||||
db: AsyncSession,
|
||||
@@ -416,7 +295,7 @@ This email was sent automatically, please do not reply.
|
||||
client_id: int | None = None,
|
||||
country_code: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""验证邮箱验证码(带智能信任更新)"""
|
||||
"""验证邮箱验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
@@ -452,16 +331,6 @@ This email was sent automatically, please do not reply.
|
||||
# 删除 Redis 记录
|
||||
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||
|
||||
# 检测客户端信息并更新信任状态
|
||||
client_info = ClientDetectionService.detect_client(user_agent, client_id)
|
||||
await DeviceTrustService.mark_verification_successful(
|
||||
redis=redis,
|
||||
user_id=user_id,
|
||||
device_fingerprint=client_info.device_fingerprint,
|
||||
country_code=country_code,
|
||||
client_info=client_info,
|
||||
)
|
||||
|
||||
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||
return True, "验证成功"
|
||||
|
||||
@@ -477,7 +346,7 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
user_agent: UserAgentInfo | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""重新发送验证码"""
|
||||
try:
|
||||
@@ -516,12 +385,12 @@ class LoginSessionService:
|
||||
|
||||
# Session verification interface methods
|
||||
@staticmethod
|
||||
async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None:
|
||||
async def find_for_verification(db: AsyncSession, token: str) -> LoginSession | None:
|
||||
"""根据会话ID查找会话用于验证"""
|
||||
try:
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_id,
|
||||
col(LoginSession.token).has(col(OAuthToken.access_token) == token),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
)
|
||||
)
|
||||
@@ -537,42 +406,31 @@ class LoginSessionService:
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
token_id: int,
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
is_new_device: bool = False,
|
||||
web_uuid: str | None = None,
|
||||
is_verified: bool = False,
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
|
||||
session_token = EmailVerificationService.generate_session_token()
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
token_id=token_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=None,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
user_agent=user_agent,
|
||||
is_new_device=is_new_device,
|
||||
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
|
||||
is_verified=is_verified,
|
||||
web_uuid=web_uuid,
|
||||
)
|
||||
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
# 存储到 Redis
|
||||
await redis.setex(
|
||||
f"login_session:{session_token}",
|
||||
86400, # 24小时
|
||||
user_id,
|
||||
)
|
||||
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new device: {is_new_device})")
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
@@ -592,35 +450,98 @@ class LoginSessionService:
|
||||
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
|
||||
async def check_trusted_device(
|
||||
db: AsyncSession, user_id: int, ip_address: str, user_agent: UserAgentInfo, web_uuid: str | None = None
|
||||
) -> bool:
|
||||
"""检查是否为新位置登录"""
|
||||
try:
|
||||
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||
thirty_days_ago = utcnow() - timedelta(days=30)
|
||||
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.created_at > thirty_days_ago,
|
||||
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
|
||||
)
|
||||
if user_agent.is_client:
|
||||
query = select(exists()).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "client",
|
||||
TrustedDevice.ip_address == ip_address,
|
||||
TrustedDevice.expires_at > utcnow(),
|
||||
)
|
||||
|
||||
existing_sessions = result.all()
|
||||
|
||||
# 如果有历史记录,则不是新位置
|
||||
return len(existing_sessions) == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||
# 出错时默认为新位置(更安全)
|
||||
return True
|
||||
else:
|
||||
if web_uuid is None:
|
||||
return False
|
||||
query = select(exists()).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "web",
|
||||
TrustedDevice.web_uuid == web_uuid,
|
||||
TrustedDevice.expires_at > utcnow(),
|
||||
)
|
||||
return (await db.exec(query)).first() or False
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
|
||||
async def create_trusted_device(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
user_agent: UserAgentInfo,
|
||||
web_uuid: str | None = None,
|
||||
) -> TrustedDevice:
|
||||
device = TrustedDevice(
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent.raw_ua,
|
||||
client_type="client" if user_agent.is_client else "web",
|
||||
web_uuid=web_uuid if not user_agent.is_client else None,
|
||||
expires_at=utcnow() + timedelta(days=settings.device_trust_duration_days),
|
||||
)
|
||||
db.add(device)
|
||||
await db.commit()
|
||||
await db.refresh(device)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_trusted_device(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
user_agent: UserAgentInfo,
|
||||
web_uuid: str | None = None,
|
||||
) -> TrustedDevice:
|
||||
if user_agent.is_client:
|
||||
query = select(TrustedDevice).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "client",
|
||||
TrustedDevice.ip_address == ip_address,
|
||||
)
|
||||
else:
|
||||
if web_uuid is None:
|
||||
raise ValueError("web_uuid is required for web clients")
|
||||
query = select(TrustedDevice).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "web",
|
||||
TrustedDevice.web_uuid == web_uuid,
|
||||
)
|
||||
|
||||
device = (await db.exec(query)).first()
|
||||
if device is None:
|
||||
device = await LoginSessionService.create_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
|
||||
else:
|
||||
device.last_used_at = utcnow()
|
||||
device.expires_at = utcnow() + timedelta(days=settings.device_trust_duration_days)
|
||||
await db.commit()
|
||||
await db.refresh(device)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
token_id: int,
|
||||
ip_address: str,
|
||||
user_agent: UserAgentInfo,
|
||||
web_uuid: str | None = None,
|
||||
) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
device_info: TrustedDevice | None = None
|
||||
if user_agent.is_client or web_uuid:
|
||||
device_info = await LoginSessionService.get_or_create_trusted_device(
|
||||
db, user_id, ip_address, user_agent, web_uuid
|
||||
)
|
||||
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
result = await db.exec(
|
||||
@@ -631,18 +552,20 @@ class LoginSessionService:
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
)
|
||||
|
||||
sessions = result.all()
|
||||
|
||||
# 标记所有会话为已验证
|
||||
for session in sessions:
|
||||
session.is_verified = True
|
||||
session.verified_at = utcnow()
|
||||
if device_info:
|
||||
session.device_id = device_info.id
|
||||
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
await LoginSessionService.clear_login_method(user_id, token_id, redis)
|
||||
await db.commit()
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
@@ -658,7 +581,7 @@ class LoginSessionService:
|
||||
await db.exec(
|
||||
select(exists()).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False, # noqa: E712
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user