feat(auth): support trusted device (#52)

New API to maintain sessions and devices:

- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}

Auth:

web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-03 11:26:43 +08:00
committed by GitHub
parent f34ed53a55
commit 40670c094b
28 changed files with 897 additions and 1456 deletions

View File

@@ -6,11 +6,14 @@ from __future__ import annotations
from datetime import timedelta
from app.database.verification import EmailVerification, LoginSession
from app.database.auth import OAuthToken
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
from app.dependencies.database import with_db
from app.dependencies.scheduler import get_scheduler
from app.log import logger
from app.utils import utcnow
from sqlmodel import col, select
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -69,7 +72,9 @@ class DatabaseCleanupService:
# 查找过期的登录会话记录
current_time = utcnow()
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
stmt = select(LoginSession).where(
LoginSession.expires_at < current_time, col(LoginSession.is_verified).is_(False)
)
result = await db.exec(stmt)
expired_sessions = result.all()
@@ -179,50 +184,109 @@ class DatabaseCleanupService:
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
async def cleanup_outdated_verified_sessions(db: AsyncSession) -> int:
"""
清理旧的已验证会话记录
清理过期会话记录
Args:
db: 数据库会话
days_old: 清理多少天前的已验证记录默认30天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已验证会话记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
stmt = select(LoginSession).where(
col(LoginSession.is_verified).is_(True), col(LoginSession.token_id).is_(None)
)
result = await db.exec(stmt)
all_verified_sessions = result.all()
# 筛选出过期的记录
old_verified_sessions = [
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time
]
# 删除旧的已验证记录
deleted_count = 0
for session in old_verified_sessions:
for session in result.all():
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
)
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} outdated verified sessions")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
logger.error(f"[Cleanup Service] Error cleaning outdated verified sessions: {e!s}")
return 0
@staticmethod
async def cleanup_outdated_trusted_devices(db: AsyncSession) -> int:
"""
清理过期的受信任设备记录
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的受信任设备记录
current_time = utcnow()
stmt = select(TrustedDevice).where(TrustedDevice.expires_at < current_time)
result = await db.exec(stmt)
expired_devices = result.all()
# 删除过期的记录
deleted_count = 0
for device in expired_devices:
await db.delete(device)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired trusted devices")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired trusted devices: {e!s}")
return 0
@staticmethod
async def cleanup_outdated_tokens(db: AsyncSession) -> int:
"""
清理过期的 OAuth 令牌
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
current_time = utcnow()
stmt = select(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
result = await db.exec(stmt)
expired_tokens = result.all()
deleted_count = 0
for token in expired_tokens:
await db.delete(token)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired OAuth tokens")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired OAuth tokens: {e!s}")
return 0
@staticmethod
@@ -250,8 +314,14 @@ class DatabaseCleanupService:
# 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# 清理30天前的已验证会话
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
# 清理过期的受信任设备
results["outdated_trusted_devices"] = await DatabaseCleanupService.cleanup_outdated_trusted_devices(db)
# 清理过期的 OAuth 令牌
results["outdated_oauth_tokens"] = await DatabaseCleanupService.cleanup_outdated_tokens(db)
# 清理过期token 过期)的已验证会话
results["outdated_verified_sessions"] = await DatabaseCleanupService.cleanup_outdated_verified_sessions(db)
total_cleaned = sum(results.values())
if total_cleaned > 0:
@@ -279,21 +349,27 @@ class DatabaseCleanupService:
cutoff_30_days = current_time - timedelta(days=30)
# 统计过期的验证码数量
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
expired_codes_stmt = (
select(func.count()).select_from(EmailVerification).where(EmailVerification.expires_at < current_time)
)
expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all())
expired_codes_count = expired_codes_result.one()
# 统计过期的登录会话数量
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
expired_sessions_stmt = (
select(func.count()).select_from(LoginSession).where(LoginSession.expires_at < current_time)
)
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
expired_sessions_count = expired_sessions_result.one()
# 统计1小时前未验证的登录会话数量
unverified_sessions_stmt = select(LoginSession).where(
col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour
unverified_sessions_stmt = (
select(func.count())
.select_from(LoginSession)
.where(col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour)
)
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
unverified_sessions_count = len(unverified_sessions_result.all())
unverified_sessions_count = unverified_sessions_result.one()
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
@@ -304,10 +380,10 @@ class DatabaseCleanupService:
)
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len(
outdated_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
outdated_verified_sessions_result = await db.exec(outdated_verified_sessions_stmt)
all_verified_sessions = outdated_verified_sessions_result.all()
outdated_verified_sessions_count = len(
[
session
for session in all_verified_sessions
@@ -315,17 +391,35 @@ class DatabaseCleanupService:
]
)
# 统计过期的 OAuth 令牌数量
outdated_tokens_stmt = (
select(func.count()).select_from(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
)
outdated_tokens_result = await db.exec(outdated_tokens_stmt)
outdated_tokens_count = outdated_tokens_result.one()
# 统计过期的受信任设备数量
outdated_devices_stmt = (
select(func.count()).select_from(TrustedDevice).where(TrustedDevice.expires_at < current_time)
)
outdated_devices_result = await db.exec(outdated_devices_stmt)
outdated_devices_count = outdated_devices_result.one()
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"unverified_login_sessions": unverified_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"outdated_verified_sessions": outdated_verified_sessions_count,
"outdated_oauth_tokens": outdated_tokens_count,
"outdated_trusted_devices": outdated_devices_count,
"total_cleanable": expired_codes_count
+ expired_sessions_count
+ unverified_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
+ outdated_verified_sessions_count
+ outdated_tokens_count
+ outdated_devices_count,
}
except Exception as e:
@@ -335,6 +429,23 @@ class DatabaseCleanupService:
"expired_login_sessions": 0,
"unverified_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"outdated_verified_sessions": 0,
"outdated_oauth_tokens": 0,
"outdated_trusted_devices": 0,
"total_cleanable": 0,
}
@get_scheduler().scheduled_job(
"interval",
id="cleanup_database",
hours=1,
)
async def scheduled_cleanup_job():
async with with_db() as session:
logger.debug("Starting database cleanup...")
results = await DatabaseCleanupService.run_full_cleanup(session)
total = sum(results.values())
if total > 0:
logger.debug(f"Cleanup completed, total records cleaned: {total}")
return results