feat(auth): support trusted device (#52)

New API to maintain sessions and devices:

- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}

Auth:

web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-03 11:26:43 +08:00
committed by GitHub
parent f34ed53a55
commit 40670c094b
28 changed files with 897 additions and 1456 deletions

View File

@@ -1,230 +0,0 @@
"""
客户端检测服务
用于识别不同类型的 osu! 客户端和设备
"""
from __future__ import annotations
from dataclasses import dataclass
import hashlib
import re
from typing import ClassVar, Literal
from app.log import logger
@dataclass
class ClientInfo:
"""客户端信息"""
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"]
platform: str | None = None
version: str | None = None
device_fingerprint: str | None = None
is_trusted_client: bool = False
class ClientDetectionService:
"""客户端检测服务"""
# osu! 客户端的 User-Agent 模式
OSU_CLIENT_PATTERNS: ClassVar[dict[str, list[str]]] = {
"osu_stable": [
r"osu!/(\d+(?:\.\d+)*)", # osu!/20241001
r"osu!", # 简单匹配
],
"osu_lazer": [
r"osu-lazer/(\d+(?:\.\d+)*)", # osu-lazer/2024.1009.0
r"osu!lazer/(\d+(?:\.\d+)*)", # osu!lazer/2024.1009.0
],
"osu_web": [
r"Mozilla.*osu\.ppy\.sh", # 网页客户端
],
"mobile": [
r"osu!.*mobile",
r"osu.*Mobile",
r"Mobile.*osu",
],
}
# 受信任的客户端类型(不需要频繁验证)
TRUSTED_CLIENT_TYPES: ClassVar[set[str]] = {"osu_stable", "osu_lazer"}
@staticmethod
def detect_client(user_agent: str | None, client_id: int | None = None) -> ClientInfo:
"""
检测客户端类型和信息
Args:
user_agent: 用户代理字符串
client_id: OAuth 客户端 ID
Returns:
ClientInfo: 客户端信息
"""
from app.config import settings # 导入在函数内部避免循环导入
if not user_agent:
return ClientInfo(client_type="unknown")
# 优先通过 client_id 判断客户端类型
if client_id is not None:
if client_id == settings.osu_client_id:
# osu! stable 客户端
return ClientInfo(
client_type="osu_stable",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=True,
)
elif client_id == settings.osu_web_client_id:
# 检查 User-Agent 是否表明这是 Lazer 客户端
if user_agent and user_agent.strip() == "osu!":
# Lazer 客户端使用 web client_id 但发送简单的 "osu!" User-Agent
return ClientInfo(
client_type="osu_lazer",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=True,
)
else:
# 真正的 web 客户端
return ClientInfo(
client_type="osu_web",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
# 回退到基于 User-Agent 的检测
for client_type_str, patterns in ClientDetectionService.OSU_CLIENT_PATTERNS.items():
for pattern in patterns:
match = re.search(pattern, user_agent, re.IGNORECASE)
if match:
version = match.group(1) if match.groups() else None
platform = ClientDetectionService._extract_platform(user_agent)
# 确保 client_type 是正确的 Literal 类型
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] = client_type_str # type: ignore
return ClientInfo(
client_type=client_type,
platform=platform,
version=version,
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=client_type in ClientDetectionService.TRUSTED_CLIENT_TYPES,
)
# 检测常见浏览器
if any(browser in user_agent.lower() for browser in ["chrome", "firefox", "safari", "edge"]):
return ClientInfo(
client_type="osu_web",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
return ClientInfo(
client_type="unknown",
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
@staticmethod
def _extract_platform(user_agent: str) -> str | None:
"""从 User-Agent 中提取平台信息"""
platforms = {
"windows": ["windows", "win32", "win64"],
"macos": ["macintosh", "mac os", "darwin"],
"linux": ["linux", "ubuntu", "debian"],
"android": ["android"],
"ios": ["iphone", "ipad", "ios"],
}
user_agent_lower = user_agent.lower()
for platform, keywords in platforms.items():
if any(keyword in user_agent_lower for keyword in keywords):
return platform
return None
@staticmethod
def _generate_device_fingerprint(user_agent: str) -> str:
"""生成设备指纹"""
# 使用 User-Agent 的哈希值作为简单的设备指纹
# 在实际应用中可以结合更多信息IP、屏幕分辨率等
return hashlib.sha256(user_agent.encode()).hexdigest()[:16]
@staticmethod
def should_skip_email_verification(
client_info: ClientInfo,
is_new_location: bool,
user_id: int,
) -> bool:
"""
判断是否应该跳过邮件验证
Args:
client_info: 客户端信息
is_new_location: 是否为新位置登录
user_id: 用户 ID
Returns:
bool: 是否应该跳过邮件验证
"""
# 受信任的客户端类型可以减少验证频率
if client_info.is_trusted_client:
logger.info(
f"[Client Detection] Trusted client {client_info.client_type} for user {user_id}, "
f"reducing verification requirements"
)
return True
# 如果不是新位置,跳过验证
if not is_new_location:
return True
return False
@staticmethod
def get_verification_cooldown(client_info: ClientInfo) -> int:
"""
获取验证冷却时间(秒)
Args:
client_info: 客户端信息
Returns:
int: 冷却时间(秒)
"""
# 受信任的客户端有更长的冷却时间
if client_info.is_trusted_client:
return 3600 # 1小时
# 网页客户端较短的冷却时间
if client_info.client_type == "osu_web":
return 1800 # 30分钟
# 未知客户端最短冷却时间
return 900 # 15分钟
@staticmethod
def format_client_display_name(client_info: ClientInfo) -> str:
"""格式化客户端显示名称"""
display_names = {
"osu_stable": "osu! (stable)",
"osu_lazer": "osu!(lazer)",
"osu_web": "osu! web",
"mobile": "osu! mobile",
"unknown": "Unknown client",
}
base_name = display_names.get(client_info.client_type, "Unknown client")
if client_info.version:
base_name += f" v{client_info.version}"
if client_info.platform:
base_name += f" ({client_info.platform})"
return base_name

View File

@@ -6,11 +6,14 @@ from __future__ import annotations
from datetime import timedelta
from app.database.verification import EmailVerification, LoginSession
from app.database.auth import OAuthToken
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
from app.dependencies.database import with_db
from app.dependencies.scheduler import get_scheduler
from app.log import logger
from app.utils import utcnow
from sqlmodel import col, select
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -69,7 +72,9 @@ class DatabaseCleanupService:
# 查找过期的登录会话记录
current_time = utcnow()
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
stmt = select(LoginSession).where(
LoginSession.expires_at < current_time, col(LoginSession.is_verified).is_(False)
)
result = await db.exec(stmt)
expired_sessions = result.all()
@@ -179,50 +184,109 @@ class DatabaseCleanupService:
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
async def cleanup_outdated_verified_sessions(db: AsyncSession) -> int:
"""
清理旧的已验证会话记录
清理过期会话记录
Args:
db: 数据库会话
days_old: 清理多少天前的已验证记录默认30天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已验证会话记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
stmt = select(LoginSession).where(
col(LoginSession.is_verified).is_(True), col(LoginSession.token_id).is_(None)
)
result = await db.exec(stmt)
all_verified_sessions = result.all()
# 筛选出过期的记录
old_verified_sessions = [
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time
]
# 删除旧的已验证记录
deleted_count = 0
for session in old_verified_sessions:
for session in result.all():
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
)
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} outdated verified sessions")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
logger.error(f"[Cleanup Service] Error cleaning outdated verified sessions: {e!s}")
return 0
@staticmethod
async def cleanup_outdated_trusted_devices(db: AsyncSession) -> int:
"""
清理过期的受信任设备记录
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的受信任设备记录
current_time = utcnow()
stmt = select(TrustedDevice).where(TrustedDevice.expires_at < current_time)
result = await db.exec(stmt)
expired_devices = result.all()
# 删除过期的记录
deleted_count = 0
for device in expired_devices:
await db.delete(device)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired trusted devices")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired trusted devices: {e!s}")
return 0
@staticmethod
async def cleanup_outdated_tokens(db: AsyncSession) -> int:
"""
清理过期的 OAuth 令牌
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
current_time = utcnow()
stmt = select(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
result = await db.exec(stmt)
expired_tokens = result.all()
deleted_count = 0
for token in expired_tokens:
await db.delete(token)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired OAuth tokens")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired OAuth tokens: {e!s}")
return 0
@staticmethod
@@ -250,8 +314,14 @@ class DatabaseCleanupService:
# 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# 清理30天前的已验证会话
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
# 清理过期的受信任设备
results["outdated_trusted_devices"] = await DatabaseCleanupService.cleanup_outdated_trusted_devices(db)
# 清理过期的 OAuth 令牌
results["outdated_oauth_tokens"] = await DatabaseCleanupService.cleanup_outdated_tokens(db)
# 清理过期token 过期)的已验证会话
results["outdated_verified_sessions"] = await DatabaseCleanupService.cleanup_outdated_verified_sessions(db)
total_cleaned = sum(results.values())
if total_cleaned > 0:
@@ -279,21 +349,27 @@ class DatabaseCleanupService:
cutoff_30_days = current_time - timedelta(days=30)
# 统计过期的验证码数量
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
expired_codes_stmt = (
select(func.count()).select_from(EmailVerification).where(EmailVerification.expires_at < current_time)
)
expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all())
expired_codes_count = expired_codes_result.one()
# 统计过期的登录会话数量
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
expired_sessions_stmt = (
select(func.count()).select_from(LoginSession).where(LoginSession.expires_at < current_time)
)
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
expired_sessions_count = expired_sessions_result.one()
# 统计1小时前未验证的登录会话数量
unverified_sessions_stmt = select(LoginSession).where(
col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour
unverified_sessions_stmt = (
select(func.count())
.select_from(LoginSession)
.where(col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour)
)
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
unverified_sessions_count = len(unverified_sessions_result.all())
unverified_sessions_count = unverified_sessions_result.one()
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
@@ -304,10 +380,10 @@ class DatabaseCleanupService:
)
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len(
outdated_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
outdated_verified_sessions_result = await db.exec(outdated_verified_sessions_stmt)
all_verified_sessions = outdated_verified_sessions_result.all()
outdated_verified_sessions_count = len(
[
session
for session in all_verified_sessions
@@ -315,17 +391,35 @@ class DatabaseCleanupService:
]
)
# 统计过期的 OAuth 令牌数量
outdated_tokens_stmt = (
select(func.count()).select_from(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
)
outdated_tokens_result = await db.exec(outdated_tokens_stmt)
outdated_tokens_count = outdated_tokens_result.one()
# 统计过期的受信任设备数量
outdated_devices_stmt = (
select(func.count()).select_from(TrustedDevice).where(TrustedDevice.expires_at < current_time)
)
outdated_devices_result = await db.exec(outdated_devices_stmt)
outdated_devices_count = outdated_devices_result.one()
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"unverified_login_sessions": unverified_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"outdated_verified_sessions": outdated_verified_sessions_count,
"outdated_oauth_tokens": outdated_tokens_count,
"outdated_trusted_devices": outdated_devices_count,
"total_cleanable": expired_codes_count
+ expired_sessions_count
+ unverified_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
+ outdated_verified_sessions_count
+ outdated_tokens_count
+ outdated_devices_count,
}
except Exception as e:
@@ -335,6 +429,23 @@ class DatabaseCleanupService:
"expired_login_sessions": 0,
"unverified_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"outdated_verified_sessions": 0,
"outdated_oauth_tokens": 0,
"outdated_trusted_devices": 0,
"total_cleanable": 0,
}
@get_scheduler().scheduled_job(
"interval",
id="cleanup_database",
hours=1,
)
async def scheduled_cleanup_job():
async with with_db() as session:
logger.debug("Starting database cleanup...")
results = await DatabaseCleanupService.run_full_cleanup(session)
total = sum(results.values())
if total > 0:
logger.debug(f"Cleanup completed, total records cleaned: {total}")
return results

View File

@@ -1,283 +0,0 @@
"""
设备信任服务
管理用户的受信任设备,减少频繁验证
"""
from __future__ import annotations
from datetime import timedelta
from app.config import settings
from app.log import logger
from app.service.client_detection_service import ClientInfo
from app.utils import utcnow
from redis.asyncio import Redis
class DeviceTrustService:
"""设备信任服务"""
@staticmethod
def _get_device_trust_key(user_id: int, device_fingerprint: str) -> str:
"""获取设备信任的 Redis 键"""
return f"device_trust:{user_id}:{device_fingerprint}"
@staticmethod
def _get_location_trust_key(user_id: int, country_code: str) -> str:
"""获取位置信任的 Redis 键"""
return f"location_trust:{user_id}:{country_code}"
@staticmethod
def _get_verification_cooldown_key(user_id: int) -> str:
"""获取验证冷却的 Redis 键"""
return f"verification_cooldown:{user_id}"
@staticmethod
async def is_device_trusted(
redis: Redis,
user_id: int,
device_fingerprint: str,
) -> bool:
"""
检查设备是否受信任
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
Returns:
bool: 设备是否受信任
"""
if not device_fingerprint:
return False
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
trust_data = await redis.get(trust_key)
return trust_data is not None
@staticmethod
async def is_location_trusted(
redis: Redis,
user_id: int,
country_code: str | None,
) -> bool:
"""
检查位置是否受信任
Args:
redis: Redis 连接
user_id: 用户 ID
country_code: 国家代码
Returns:
bool: 位置是否受信任
"""
if not country_code:
return False
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
trust_data = await redis.get(trust_key)
return trust_data is not None
@staticmethod
async def is_in_verification_cooldown(
redis: Redis,
user_id: int,
) -> bool:
"""
检查用户是否在验证冷却期内
Args:
redis: Redis 连接
user_id: 用户 ID
Returns:
bool: 是否在冷却期内
"""
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
cooldown_data = await redis.get(cooldown_key)
return cooldown_data is not None
@staticmethod
async def trust_device(
redis: Redis,
user_id: int,
device_fingerprint: str,
client_info: ClientInfo,
trust_duration_days: int | None = None,
) -> None:
"""
信任设备
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
client_info: 客户端信息
trust_duration_days: 信任持续天数
"""
if not device_fingerprint:
return
# 使用配置中的默认值
if trust_duration_days is None:
trust_duration_days = settings.device_trust_duration_days
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
trust_data = {
"client_type": client_info.client_type,
"platform": client_info.platform or "unknown",
"trusted_at": utcnow().isoformat(),
}
# 设置信任期限
trust_duration_seconds = trust_duration_days * 24 * 3600
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
logger.info(
f"[Device Trust] Device trusted for user {user_id}: "
f"{client_info.client_type} on {client_info.platform} "
f"(fingerprint: {device_fingerprint[:8]}...)"
)
@staticmethod
async def trust_location(
redis: Redis,
user_id: int,
country_code: str,
trust_duration_days: int | None = None,
) -> None:
"""
信任位置
Args:
redis: Redis 连接
user_id: 用户 ID
country_code: 国家代码
trust_duration_days: 信任持续天数
"""
if not country_code:
return
# 使用配置中的默认值
if trust_duration_days is None:
trust_duration_days = settings.location_trust_duration_days
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
trust_data = {
"country_code": country_code,
"trusted_at": utcnow().isoformat(),
}
# 设置信任期限
trust_duration_seconds = trust_duration_days * 24 * 3600
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
logger.info(f"[Location Trust] Location trusted for user {user_id}: {country_code}")
@staticmethod
async def set_verification_cooldown(
redis: Redis,
user_id: int,
cooldown_seconds: int,
) -> None:
"""
设置验证冷却期
Args:
redis: Redis 连接
user_id: 用户 ID
cooldown_seconds: 冷却时间(秒)
"""
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
cooldown_data = {
"set_at": utcnow().isoformat(),
"expires_at": (utcnow() + timedelta(seconds=cooldown_seconds)).isoformat(),
}
await redis.setex(cooldown_key, cooldown_seconds, str(cooldown_data))
logger.info(f"[Verification Cooldown] Set cooldown for user {user_id}: {cooldown_seconds}s")
@staticmethod
async def should_require_verification(
redis: Redis,
user_id: int,
device_fingerprint: str | None,
country_code: str | None,
client_info: ClientInfo,
is_new_location: bool,
) -> tuple[bool, str]:
"""
判断是否需要验证
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
country_code: 国家代码
client_info: 客户端信息
is_new_location: 是否为新位置
Returns:
tuple[bool, str]: (是否需要验证, 原因)
"""
# 检查验证冷却期
if await DeviceTrustService.is_in_verification_cooldown(redis, user_id):
return False, "用户在验证冷却期内"
# 检查设备信任
if device_fingerprint and await DeviceTrustService.is_device_trusted(redis, user_id, device_fingerprint):
return False, "设备已受信任"
# 检查位置信任
if country_code and await DeviceTrustService.is_location_trusted(redis, user_id, country_code):
return False, "位置已受信任"
# 受信任的客户端类型降低验证要求
if client_info.is_trusted_client and not is_new_location:
return False, "受信任客户端且非新位置"
# 如果是新位置登录,需要验证
if is_new_location:
return True, "新位置登录需要验证"
# 默认不需要验证
return False, "常规登录无需验证"
@staticmethod
async def mark_verification_successful(
redis: Redis,
user_id: int,
device_fingerprint: str | None,
country_code: str | None,
client_info: ClientInfo,
) -> None:
"""
标记验证成功,更新信任信息
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
country_code: 国家代码
client_info: 客户端信息
"""
# 信任设备
if device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, device_fingerprint, client_info)
# 信任位置
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
# 设置验证冷却期
cooldown_seconds = (client_info.is_trusted_client and 3600) or 1800 # 受信任客户端1小时其他30分钟
await DeviceTrustService.set_verification_cooldown(redis, user_id, cooldown_seconds)
logger.info(f"[Device Trust] Verification successful for user {user_id}, trust updated")

View File

@@ -9,7 +9,7 @@ import asyncio
from app.database.user_login_log import UserLoginLog
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
from app.log import logger
from app.utils import simplify_user_agent, utcnow
from app.utils import utcnow
from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -23,6 +23,7 @@ class LoginLogService:
db: AsyncSession,
user_id: int,
request: Request,
user_agent: str | None = None,
login_success: bool = True,
login_method: str = "password",
notes: str | None = None,
@@ -45,9 +46,6 @@ class LoginLogService:
raw_ip = get_client_ip(request)
ip_address = normalize_ip(raw_ip)
raw_user_agent = request.headers.get("User-Agent", "")
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
# 创建基本的登录记录
login_log = UserLoginLog(
user_id=user_id,
@@ -107,6 +105,7 @@ class LoginLogService:
attempted_username: str | None = None,
login_method: str = "password",
notes: str | None = None,
user_agent: str | None = None,
) -> UserLoginLog:
"""
记录失败的登录尝试
@@ -128,6 +127,7 @@ class LoginLogService:
request=request,
login_success=False,
login_method=login_method,
user_agent=user_agent,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
)

View File

@@ -1,122 +0,0 @@
"""
API 状态管理 - 模拟 osu! 的 APIState 和会话管理
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
class APIState(str, Enum):
"""API 连接状态,对应 osu! 的 APIState"""
OFFLINE = "offline"
CONNECTING = "connecting"
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
ONLINE = "online"
FAILING = "failing"
class UserSession(BaseModel):
"""用户会话信息"""
user_id: int
username: str
email: str
session_token: str | None = None
state: APIState = APIState.OFFLINE
requires_verification: bool = False
verification_sent: bool = False
last_verification_attempt: datetime | None = None
failed_attempts: int = 0
ip_address: str | None = None
country_code: str | None = None
is_new_location: bool = False
class SessionManager:
"""会话管理器"""
def __init__(self):
self._sessions: dict[str, UserSession] = {}
def create_session(
self,
user_id: int,
username: str,
email: str,
ip_address: str,
country_code: str | None = None,
is_new_location: bool = False,
) -> UserSession:
"""创建新的用户会话"""
import secrets
session_token = secrets.token_urlsafe(32)
# 根据是否为新位置决定初始状态
if is_new_location:
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
else:
state = APIState.ONLINE
session = UserSession(
user_id=user_id,
username=username,
email=email,
session_token=session_token,
state=state,
requires_verification=is_new_location,
ip_address=ip_address,
country_code=country_code,
is_new_location=is_new_location,
)
self._sessions[session_token] = session
return session
def get_session(self, session_token: str) -> UserSession | None:
"""获取会话"""
return self._sessions.get(session_token)
def update_session_state(self, session_token: str, state: APIState):
"""更新会话状态"""
if session_token in self._sessions:
self._sessions[session_token].state = state
def mark_verification_sent(self, session_token: str):
"""标记验证邮件已发送"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.verification_sent = True
session.last_verification_attempt = datetime.now()
def increment_failed_attempts(self, session_token: str):
"""增加失败尝试次数"""
if session_token in self._sessions:
self._sessions[session_token].failed_attempts += 1
def verify_session(self, session_token: str) -> bool:
"""验证会话成功"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.state = APIState.ONLINE
session.requires_verification = False
return True
return False
def remove_session(self, session_token: str):
"""移除会话"""
self._sessions.pop(session_token, None)
def cleanup_expired_sessions(self):
"""清理过期会话"""
# 这里可以实现清理逻辑
pass
# 全局会话管理器
session_manager = SessionManager()

View File

@@ -10,15 +10,15 @@ import string
from typing import Literal
from app.config import settings
from app.database.verification import EmailVerification, LoginSession
from app.database.auth import OAuthToken
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
from app.log import logger
from app.service.client_detection_service import ClientDetectionService, ClientInfo
from app.service.device_trust_service import DeviceTrustService
from app.service.email_queue import email_queue # 导入邮件队列
from app.models.model import UserAgentInfo
from app.service.email_queue import email_queue
from app.utils import utcnow
from redis.asyncio import Redis
from sqlmodel import exists, select
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -248,11 +248,9 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
user_agent: UserAgentInfo | None = None,
) -> bool:
"""发送验证邮件(带智能检测)"""
"""发送验证邮件"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
@@ -260,32 +258,14 @@ This email was sent automatically, please do not reply.
return True # 返回成功,但不执行验证流程
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
logger.info(
f"[Email Verification] Detected client for user {user_id}: "
f"{ClientDetectionService.format_client_display_name(client_info)}"
)
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=True, # 这里需要从调用方传入
)
if not needs_verification:
logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}")
return True
logger.info(f"[Email Verification] Detected client for user {user_id}: {user_agent}")
# 创建验证记录
(
_,
code,
) = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
db, redis, user_id, email, ip_address, user_agent.raw_ua if user_agent else None
)
# 使用邮件队列发送验证邮件
@@ -304,107 +284,6 @@ This email was sent automatically, please do not reply.
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False
@staticmethod
async def send_smart_verification_email(
db: AsyncSession,
redis: Redis,
user_id: int,
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
is_new_location: bool = False,
) -> tuple[bool, str, ClientInfo | None]:
"""
智能邮件验证发送
Args:
db: 数据库会话
redis: Redis 连接
user_id: 用户 ID
username: 用户名
email: 邮箱地址
ip_address: IP 地址
user_agent: 用户代理
client_id: 客户端 ID
country_code: 国家代码
is_new_location: 是否为新位置登录
Returns:
tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息)
"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}")
return True, "邮件验证功能已禁用", None
# 检查是否启用智能验证
if not settings.enable_smart_verification:
logger.debug(
f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}"
)
# 回退到传统验证逻辑
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
success = await EmailVerificationService.send_verification_email_via_queue(
email, code, username, user_id
)
return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
client_display_name = ClientDetectionService.format_client_display_name(client_info)
logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}")
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=is_new_location,
)
if not needs_verification:
logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}")
# 即使不需要验证,也要更新设备信任信息
if client_info.device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info)
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
return True, f"跳过验证: {reason}", client_info
# 创建验证记录
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
_ = verification # 避免未使用变量警告
# 使用邮件队列发送验证邮件
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success:
logger.info(
f"[Smart Verification] Successfully sent verification email to {email} "
f"for user {username} using {client_display_name}"
)
return True, "验证邮件已发送", client_info
else:
logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})")
return False, "验证邮件发送失败", client_info
except Exception as e:
logger.error(f"[Smart Verification] Exception during smart verification: {e}")
return False, f"验证过程中发生错误: {e!s}", None
@staticmethod
async def verify_email_code(
db: AsyncSession,
@@ -416,7 +295,7 @@ This email was sent automatically, please do not reply.
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]:
"""验证邮箱验证码(带智能信任更新)"""
"""验证邮箱验证码"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
@@ -452,16 +331,6 @@ This email was sent automatically, please do not reply.
# 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}")
# 检测客户端信息并更新信任状态
client_info = ClientDetectionService.detect_client(user_agent, client_id)
await DeviceTrustService.mark_verification_successful(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
)
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功"
@@ -477,7 +346,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
user_agent: UserAgentInfo | None = None,
) -> tuple[bool, str]:
"""重新发送验证码"""
try:
@@ -516,12 +385,12 @@ class LoginSessionService:
# Session verification interface methods
@staticmethod
async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None:
async def find_for_verification(db: AsyncSession, token: str) -> LoginSession | None:
"""根据会话ID查找会话用于验证"""
try:
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_id,
col(LoginSession.token).has(col(OAuthToken.access_token) == token),
LoginSession.expires_at > utcnow(),
)
)
@@ -537,42 +406,31 @@ class LoginSessionService:
@staticmethod
async def create_session(
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False,
is_new_device: bool = False,
web_uuid: str | None = None,
is_verified: bool = False,
) -> LoginSession:
"""创建登录会话"""
session_token = EmailVerificationService.generate_session_token()
session = LoginSession(
user_id=user_id,
token_id=token_id,
ip_address=ip_address,
user_agent=None,
country_code=country_code,
is_new_location=is_new_location,
user_agent=user_agent,
is_new_device=is_new_device,
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
is_verified=is_verified,
web_uuid=web_uuid,
)
db.add(session)
await db.commit()
await db.refresh(session)
# 存储到 Redis
await redis.setex(
f"login_session:{session_token}",
86400, # 24小时
user_id,
)
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
logger.info(f"[Login Session] Created session for user {user_id} (new device: {is_new_device})")
return session
@classmethod
@@ -592,35 +450,98 @@ class LoginSessionService:
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
@staticmethod
async def check_new_location(
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
async def check_trusted_device(
db: AsyncSession, user_id: int, ip_address: str, user_agent: UserAgentInfo, web_uuid: str | None = None
) -> bool:
"""检查是否为新位置登录"""
try:
# 查看过去30天内是否有相同IP或相同国家的登录记录
thirty_days_ago = utcnow() - timedelta(days=30)
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
)
if user_agent.is_client:
query = select(exists()).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "client",
TrustedDevice.ip_address == ip_address,
TrustedDevice.expires_at > utcnow(),
)
existing_sessions = result.all()
# 如果有历史记录,则不是新位置
return len(existing_sessions) == 0
except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}")
# 出错时默认为新位置(更安全)
return True
else:
if web_uuid is None:
return False
query = select(exists()).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "web",
TrustedDevice.web_uuid == web_uuid,
TrustedDevice.expires_at > utcnow(),
)
return (await db.exec(query)).first() or False
@staticmethod
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
async def create_trusted_device(
db: AsyncSession,
user_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> TrustedDevice:
device = TrustedDevice(
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent.raw_ua,
client_type="client" if user_agent.is_client else "web",
web_uuid=web_uuid if not user_agent.is_client else None,
expires_at=utcnow() + timedelta(days=settings.device_trust_duration_days),
)
db.add(device)
await db.commit()
await db.refresh(device)
return device
@staticmethod
async def get_or_create_trusted_device(
db: AsyncSession,
user_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> TrustedDevice:
if user_agent.is_client:
query = select(TrustedDevice).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "client",
TrustedDevice.ip_address == ip_address,
)
else:
if web_uuid is None:
raise ValueError("web_uuid is required for web clients")
query = select(TrustedDevice).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "web",
TrustedDevice.web_uuid == web_uuid,
)
device = (await db.exec(query)).first()
if device is None:
device = await LoginSessionService.create_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
else:
device.last_used_at = utcnow()
device.expires_at = utcnow() + timedelta(days=settings.device_trust_duration_days)
await db.commit()
await db.refresh(device)
return device
@staticmethod
async def mark_session_verified(
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> bool:
"""标记用户的未验证会话为已验证"""
device_info: TrustedDevice | None = None
if user_agent.is_client or web_uuid:
device_info = await LoginSessionService.get_or_create_trusted_device(
db, user_id, ip_address, user_agent, web_uuid
)
try:
# 查找用户所有未验证且未过期的会话
result = await db.exec(
@@ -631,18 +552,20 @@ class LoginSessionService:
LoginSession.token_id == token_id,
)
)
sessions = result.all()
# 标记所有会话为已验证
for session in sessions:
session.is_verified = True
session.verified_at = utcnow()
if device_info:
session.device_id = device_info.id
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
await LoginSessionService.clear_login_method(user_id, token_id, redis)
await db.commit()
return len(sessions) > 0
@@ -658,7 +581,7 @@ class LoginSessionService:
await db.exec(
select(exists()).where(
LoginSession.user_id == user_id,
LoginSession.is_verified == False, # noqa: E712
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
)