feat(auth): support trusted device (#52)

New API to maintain sessions and devices:

- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}

Auth:

web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-03 11:26:43 +08:00
committed by GitHub
parent f34ed53a55
commit 40670c094b
28 changed files with 897 additions and 1456 deletions

View File

@@ -10,15 +10,15 @@ import string
from typing import Literal
from app.config import settings
from app.database.verification import EmailVerification, LoginSession
from app.database.auth import OAuthToken
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
from app.log import logger
from app.service.client_detection_service import ClientDetectionService, ClientInfo
from app.service.device_trust_service import DeviceTrustService
from app.service.email_queue import email_queue # 导入邮件队列
from app.models.model import UserAgentInfo
from app.service.email_queue import email_queue
from app.utils import utcnow
from redis.asyncio import Redis
from sqlmodel import exists, select
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -248,11 +248,9 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
user_agent: UserAgentInfo | None = None,
) -> bool:
"""发送验证邮件(带智能检测)"""
"""发送验证邮件"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
@@ -260,32 +258,14 @@ This email was sent automatically, please do not reply.
return True # 返回成功,但不执行验证流程
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
logger.info(
f"[Email Verification] Detected client for user {user_id}: "
f"{ClientDetectionService.format_client_display_name(client_info)}"
)
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=True, # 这里需要从调用方传入
)
if not needs_verification:
logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}")
return True
logger.info(f"[Email Verification] Detected client for user {user_id}: {user_agent}")
# 创建验证记录
(
_,
code,
) = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
db, redis, user_id, email, ip_address, user_agent.raw_ua if user_agent else None
)
# 使用邮件队列发送验证邮件
@@ -304,107 +284,6 @@ This email was sent automatically, please do not reply.
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False
@staticmethod
async def send_smart_verification_email(
db: AsyncSession,
redis: Redis,
user_id: int,
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
is_new_location: bool = False,
) -> tuple[bool, str, ClientInfo | None]:
"""
智能邮件验证发送
Args:
db: 数据库会话
redis: Redis 连接
user_id: 用户 ID
username: 用户名
email: 邮箱地址
ip_address: IP 地址
user_agent: 用户代理
client_id: 客户端 ID
country_code: 国家代码
is_new_location: 是否为新位置登录
Returns:
tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息)
"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}")
return True, "邮件验证功能已禁用", None
# 检查是否启用智能验证
if not settings.enable_smart_verification:
logger.debug(
f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}"
)
# 回退到传统验证逻辑
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
success = await EmailVerificationService.send_verification_email_via_queue(
email, code, username, user_id
)
return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
client_display_name = ClientDetectionService.format_client_display_name(client_info)
logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}")
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=is_new_location,
)
if not needs_verification:
logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}")
# 即使不需要验证,也要更新设备信任信息
if client_info.device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info)
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
return True, f"跳过验证: {reason}", client_info
# 创建验证记录
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
_ = verification # 避免未使用变量警告
# 使用邮件队列发送验证邮件
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success:
logger.info(
f"[Smart Verification] Successfully sent verification email to {email} "
f"for user {username} using {client_display_name}"
)
return True, "验证邮件已发送", client_info
else:
logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})")
return False, "验证邮件发送失败", client_info
except Exception as e:
logger.error(f"[Smart Verification] Exception during smart verification: {e}")
return False, f"验证过程中发生错误: {e!s}", None
@staticmethod
async def verify_email_code(
db: AsyncSession,
@@ -416,7 +295,7 @@ This email was sent automatically, please do not reply.
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]:
"""验证邮箱验证码(带智能信任更新)"""
"""验证邮箱验证码"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
@@ -452,16 +331,6 @@ This email was sent automatically, please do not reply.
# 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}")
# 检测客户端信息并更新信任状态
client_info = ClientDetectionService.detect_client(user_agent, client_id)
await DeviceTrustService.mark_verification_successful(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
)
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功"
@@ -477,7 +346,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
user_agent: UserAgentInfo | None = None,
) -> tuple[bool, str]:
"""重新发送验证码"""
try:
@@ -516,12 +385,12 @@ class LoginSessionService:
# Session verification interface methods
@staticmethod
async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None:
async def find_for_verification(db: AsyncSession, token: str) -> LoginSession | None:
"""根据会话ID查找会话用于验证"""
try:
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_id,
col(LoginSession.token).has(col(OAuthToken.access_token) == token),
LoginSession.expires_at > utcnow(),
)
)
@@ -537,42 +406,31 @@ class LoginSessionService:
@staticmethod
async def create_session(
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False,
is_new_device: bool = False,
web_uuid: str | None = None,
is_verified: bool = False,
) -> LoginSession:
"""创建登录会话"""
session_token = EmailVerificationService.generate_session_token()
session = LoginSession(
user_id=user_id,
token_id=token_id,
ip_address=ip_address,
user_agent=None,
country_code=country_code,
is_new_location=is_new_location,
user_agent=user_agent,
is_new_device=is_new_device,
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
is_verified=is_verified,
web_uuid=web_uuid,
)
db.add(session)
await db.commit()
await db.refresh(session)
# 存储到 Redis
await redis.setex(
f"login_session:{session_token}",
86400, # 24小时
user_id,
)
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
logger.info(f"[Login Session] Created session for user {user_id} (new device: {is_new_device})")
return session
@classmethod
@@ -592,35 +450,98 @@ class LoginSessionService:
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
@staticmethod
async def check_new_location(
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
async def check_trusted_device(
db: AsyncSession, user_id: int, ip_address: str, user_agent: UserAgentInfo, web_uuid: str | None = None
) -> bool:
"""检查是否为新位置登录"""
try:
# 查看过去30天内是否有相同IP或相同国家的登录记录
thirty_days_ago = utcnow() - timedelta(days=30)
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
)
if user_agent.is_client:
query = select(exists()).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "client",
TrustedDevice.ip_address == ip_address,
TrustedDevice.expires_at > utcnow(),
)
existing_sessions = result.all()
# 如果有历史记录,则不是新位置
return len(existing_sessions) == 0
except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}")
# 出错时默认为新位置(更安全)
return True
else:
if web_uuid is None:
return False
query = select(exists()).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "web",
TrustedDevice.web_uuid == web_uuid,
TrustedDevice.expires_at > utcnow(),
)
return (await db.exec(query)).first() or False
@staticmethod
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
async def create_trusted_device(
db: AsyncSession,
user_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> TrustedDevice:
device = TrustedDevice(
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent.raw_ua,
client_type="client" if user_agent.is_client else "web",
web_uuid=web_uuid if not user_agent.is_client else None,
expires_at=utcnow() + timedelta(days=settings.device_trust_duration_days),
)
db.add(device)
await db.commit()
await db.refresh(device)
return device
@staticmethod
async def get_or_create_trusted_device(
db: AsyncSession,
user_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> TrustedDevice:
if user_agent.is_client:
query = select(TrustedDevice).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "client",
TrustedDevice.ip_address == ip_address,
)
else:
if web_uuid is None:
raise ValueError("web_uuid is required for web clients")
query = select(TrustedDevice).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "web",
TrustedDevice.web_uuid == web_uuid,
)
device = (await db.exec(query)).first()
if device is None:
device = await LoginSessionService.create_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
else:
device.last_used_at = utcnow()
device.expires_at = utcnow() + timedelta(days=settings.device_trust_duration_days)
await db.commit()
await db.refresh(device)
return device
@staticmethod
async def mark_session_verified(
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> bool:
"""标记用户的未验证会话为已验证"""
device_info: TrustedDevice | None = None
if user_agent.is_client or web_uuid:
device_info = await LoginSessionService.get_or_create_trusted_device(
db, user_id, ip_address, user_agent, web_uuid
)
try:
# 查找用户所有未验证且未过期的会话
result = await db.exec(
@@ -631,18 +552,20 @@ class LoginSessionService:
LoginSession.token_id == token_id,
)
)
sessions = result.all()
# 标记所有会话为已验证
for session in sessions:
session.is_verified = True
session.verified_at = utcnow()
if device_info:
session.device_id = device_info.id
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
await LoginSessionService.clear_login_method(user_id, token_id, redis)
await db.commit()
return len(sessions) > 0
@@ -658,7 +581,7 @@ class LoginSessionService:
await db.exec(
select(exists()).where(
LoginSession.user_id == user_id,
LoginSession.is_verified == False, # noqa: E712
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
)