添加邮件验证

This commit is contained in:
咕谷酱
2025-08-22 08:19:12 +08:00
parent 42f17d0c66
commit 3bee2421fa
19 changed files with 1594 additions and 22 deletions

View File

@@ -0,0 +1,289 @@
"""
数据库清理服务 - 清理过期的验证码和会话
"""
from __future__ import annotations
from datetime import datetime, UTC, timedelta
from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy import and_
class DatabaseCleanupService:
"""数据库清理服务"""
@staticmethod
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
"""
清理过期的邮件验证码
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的验证码记录
current_time = datetime.now(UTC)
stmt = select(EmailVerification).where(
EmailVerification.expires_at < current_time
)
result = await db.exec(stmt)
expired_codes = result.all()
# 删除过期的记录
deleted_count = 0
for code in expired_codes:
await db.delete(code)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
return 0
@staticmethod
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
"""
清理过期的登录会话
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的登录会话记录
current_time = datetime.now(UTC)
stmt = select(LoginSession).where(
LoginSession.expires_at < current_time
)
result = await db.exec(stmt)
expired_sessions = result.all()
# 删除过期的记录
deleted_count = 0
for session in expired_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
return 0
@staticmethod
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
"""
清理旧的已使用验证码记录
Args:
db: 数据库会话
days_old: 清理多少天前的已使用记录默认7天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已使用验证码记录
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(EmailVerification).where(
EmailVerification.is_used == True
)
result = await db.exec(stmt)
all_used_codes = result.all()
# 筛选出过期的记录
old_used_codes = [
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_time
]
# 删除旧的已使用记录
deleted_count = 0
for code in old_used_codes:
await db.delete(code)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
"""
清理旧的已验证会话记录
Args:
db: 数据库会话
days_old: 清理多少天前的已验证记录默认30天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已验证会话记录
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(LoginSession).where(
LoginSession.is_verified == True
)
result = await db.exec(stmt)
all_verified_sessions = result.all()
# 筛选出过期的记录
old_verified_sessions = [
session for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time
]
# 删除旧的已验证记录
deleted_count = 0
for session in old_verified_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
return 0
@staticmethod
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
"""
运行完整的清理流程
Args:
db: 数据库会话
Returns:
dict: 各项清理的结果统计
"""
results = {}
# 清理过期的验证码
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# 清理过期的登录会话
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# 清理30天前的已验证会话
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
total_cleaned = sum(results.values())
if total_cleaned > 0:
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
return results
@staticmethod
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
"""
获取清理统计信息
Args:
db: 数据库会话
Returns:
dict: 统计信息
"""
try:
current_time = datetime.now(UTC)
cutoff_7_days = current_time - timedelta(days=7)
cutoff_30_days = current_time - timedelta(days=30)
# 统计过期的验证码数量
expired_codes_stmt = select(EmailVerification).where(
EmailVerification.expires_at < current_time
)
expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all())
# 统计过期的登录会话数量
expired_sessions_stmt = select(LoginSession).where(
LoginSession.expires_at < current_time
)
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(
EmailVerification.is_used == True
)
old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all()
old_used_codes_count = len([
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_7_days
])
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == True
)
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len([
session for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_30_days
])
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
}
except Exception as e:
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
return {
"expired_verification_codes": 0,
"expired_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"total_cleanable": 0
}

View File

@@ -0,0 +1,167 @@
"""
邮件验证服务
"""
from __future__ import annotations
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import secrets
import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.config import settings
from app.log import logger
class EmailService:
"""邮件发送服务"""
def __init__(self):
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
self.smtp_port = getattr(settings, 'smtp_port', 587)
self.smtp_username = getattr(settings, 'smtp_username', '')
self.smtp_password = getattr(settings, 'smtp_password', '')
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
self.from_name = getattr(settings, 'from_name', 'osu! server')
def generate_verification_code(self) -> str:
"""生成8位验证码"""
# 只使用数字,避免混淆
return ''.join(secrets.choice(string.digits) for _ in range(8))
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
"""发送验证邮件"""
try:
msg = MIMEMultipart()
msg['From'] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email
msg['Subject'] = "邮箱验证 - Email Verification"
# HTML 邮件内容
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
.container {{
max-width: 600px;
margin: 0 auto;
font-family: Arial, sans-serif;
line-height: 1.6;
}}
.header {{
background: linear-gradient(135deg, #ff66aa, #ff9966);
color: white;
padding: 20px;
text-align: center;
border-radius: 10px 10px 0 0;
}}
.content {{
background: #f9f9f9;
padding: 30px;
border: 1px solid #ddd;
}}
.code {{
background: #fff;
border: 2px solid #ff66aa;
border-radius: 8px;
padding: 15px;
text-align: center;
font-size: 24px;
font-weight: bold;
letter-spacing: 3px;
margin: 20px 0;
color: #333;
}}
.footer {{
background: #333;
color: #fff;
padding: 15px;
text-align: center;
border-radius: 0 0 10px 10px;
font-size: 12px;
}}
.warning {{
background: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 5px;
padding: 10px;
margin: 15px 0;
color: #856404;
}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1> osu! 邮箱验证</h1>
<p>Email Verification</p>
</div>
<div class="content">
<h2>你好 {username}</h2>
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
<div class="code">{code}</div>
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
<div class="warning">
<strong>注意:</strong>
<ul>
<li>请不要与任何人分享这个验证码</li>
<li>如果你没有请求此验证码,请忽略这封邮件</li>
<li>验证码只能使用一次</li>
</ul>
</div>
<p>如果你有任何问题,请联系我们的支持团队。</p>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3>
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div>
<div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p>
</div>
</div>
</body>
</html>
"""
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
# 发送邮件
if not settings.enable_email_sending:
# 邮件发送功能禁用时只记录日志,不实际发送
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
return True
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password:
server.starttls()
server.login(self.smtp_username, self.smtp_password)
server.send_message(msg)
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
return True
except Exception as e:
logger.error(f"[Email Verification] Failed to send email: {e}")
return False
# 全局邮件服务实例
email_service = EmailService()

View File

@@ -0,0 +1,367 @@
"""
邮件验证管理服务
"""
from __future__ import annotations
import secrets
import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_service import email_service
from app.log import logger
from app.config import settings
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import select
from redis.asyncio import Redis
class EmailVerificationService:
"""邮件验证服务"""
@staticmethod
def generate_verification_code() -> str:
"""生成8位验证码"""
return ''.join(secrets.choice(string.digits) for _ in range(8))
@staticmethod
def generate_session_token() -> str:
"""生成会话令牌"""
return secrets.token_urlsafe(32)
@staticmethod
async def create_verification_record(
db: AsyncSession,
redis: Redis,
user_id: int,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
) -> tuple[EmailVerification, str]:
"""创建邮件验证记录"""
# 检查是否有未过期的验证码
existing_result = await db.exec(
select(EmailVerification).where(
EmailVerification.user_id == user_id,
EmailVerification.is_used == False,
EmailVerification.expires_at > datetime.now(UTC)
)
)
existing = existing_result.first()
if existing:
# 如果有未过期的验证码,直接返回
return existing, existing.verification_code
# 生成新的验证码
code = EmailVerificationService.generate_verification_code()
# 创建验证记录
verification = EmailVerification(
user_id=user_id,
email=email,
verification_code=code,
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
ip_address=ip_address,
user_agent=user_agent
)
db.add(verification)
await db.commit()
await db.refresh(verification)
# 存储到 Redis用于快速验证
await redis.setex(
f"email_verification:{user_id}:{code}",
600, # 10分钟过期
str(verification.id) if verification.id else "0"
)
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
return verification, code
@staticmethod
async def send_verification_email(
db: AsyncSession,
redis: Redis,
user_id: int,
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
) -> bool:
"""发送验证邮件"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
return True # 返回成功,但不执行验证流程
# 创建验证记录
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
# 发送邮件
success = await email_service.send_verification_email(email, code, username)
if success:
logger.info(f"[Email Verification] Successfully sent verification email to {email} (user: {username})")
return True
else:
logger.error(f"[Email Verification] Failed to send verification email: {email} (user: {username})")
return False
except Exception as e:
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False
@staticmethod
async def verify_code(
db: AsyncSession,
redis: Redis,
user_id: int,
code: str,
ip_address: str | None = None
) -> tuple[bool, str]:
"""验证验证码"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
# 仍然标记登录会话为已验证
await LoginSessionService.mark_session_verified(db, user_id)
return True, "验证成功(邮件验证功能已禁用)"
# 先从 Redis 检查
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
if not verification_id:
return False, "验证码无效或已过期"
# 从数据库获取验证记录
result = await db.exec(
select(EmailVerification).where(
EmailVerification.id == int(verification_id),
EmailVerification.user_id == user_id,
EmailVerification.verification_code == code,
EmailVerification.is_used == False,
EmailVerification.expires_at > datetime.now(UTC)
)
)
verification = result.first()
if not verification:
return False, "验证码无效或已过期"
# 标记为已使用
verification.is_used = True
verification.used_at = datetime.now(UTC)
# 同时更新对应的登录会话状态
await LoginSessionService.mark_session_verified(db, user_id)
await db.commit()
# 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}")
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功"
except Exception as e:
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
return False, "验证过程中发生错误"
@staticmethod
async def resend_verification_code(
db: AsyncSession,
redis: Redis,
user_id: int,
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
) -> tuple[bool, str]:
"""重新发送验证码"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
return True, "验证码已发送(邮件验证功能已禁用)"
# 检查重发频率限制60秒内只能发送一次
rate_limit_key = f"email_verification_rate_limit:{user_id}"
if await redis.get(rate_limit_key):
return False, "请等待60秒后再重新发送"
# 设置频率限制
await redis.setex(rate_limit_key, 60, "1")
# 生成新的验证码
success = await EmailVerificationService.send_verification_email(
db, redis, user_id, username, email, ip_address, user_agent
)
if success:
return True, "验证码已重新发送"
else:
return False, "重新发送失败,请稍后再试"
except Exception as e:
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
return False, "重新发送过程中发生错误"
class LoginSessionService:
"""登录会话服务"""
@staticmethod
async def create_session(
db: AsyncSession,
redis: Redis,
user_id: int,
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False
) -> LoginSession:
"""创建登录会话"""
session_token = EmailVerificationService.generate_session_token()
session = LoginSession(
user_id=user_id,
session_token=session_token,
ip_address=ip_address,
user_agent=user_agent,
country_code=country_code,
is_new_location=is_new_location,
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
is_verified=not is_new_location # 新位置需要验证
)
db.add(session)
await db.commit()
await db.refresh(session)
# 存储到 Redis
await redis.setex(
f"login_session:{session_token}",
86400, # 24小时
user_id
)
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
return session
@staticmethod
async def verify_session(
db: AsyncSession,
redis: Redis,
session_token: str,
verification_code: str
) -> tuple[bool, str]:
"""验证会话(通过邮件验证码)"""
try:
# 从 Redis 获取用户ID
user_id = await redis.get(f"login_session:{session_token}")
if not user_id:
return False, "会话无效或已过期"
user_id = int(user_id)
# 验证邮件验证码
success, message = await EmailVerificationService.verify_code(
db, redis, user_id, verification_code
)
if not success:
return False, message
# 更新会话状态
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_token,
LoginSession.user_id == user_id,
LoginSession.is_verified == False
)
)
session = result.first()
if session:
session.is_verified = True
session.verified_at = datetime.now(UTC)
await db.commit()
logger.info(f"[Login Session] User {user_id} session verification successful")
return True, "会话验证成功"
except Exception as e:
logger.error(f"[Login Session] Exception during session verification: {e}")
return False, "验证过程中发生错误"
@staticmethod
async def check_new_location(
db: AsyncSession,
user_id: int,
ip_address: str,
country_code: str | None = None
) -> bool:
"""检查是否为新位置登录"""
try:
# 查看过去30天内是否有相同IP或相同国家的登录记录
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) |
(LoginSession.country_code == country_code)
)
)
existing_sessions = result.all()
# 如果有历史记录,则不是新位置
return len(existing_sessions) == 0
except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}")
# 出错时默认为新位置(更安全)
return True
@staticmethod
async def mark_session_verified(
db: AsyncSession,
user_id: int
) -> bool:
"""标记用户的未验证会话为已验证"""
try:
# 查找用户所有未验证且未过期的会话
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.is_verified == False,
LoginSession.expires_at > datetime.now(UTC)
)
)
sessions = result.all()
# 标记所有会话为已验证
for session in sessions:
session.is_verified = True
session.verified_at = datetime.now(UTC)
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
return len(sessions) > 0
except Exception as e:
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
return False

View File

@@ -65,6 +65,17 @@ class RedisMessageSystem:
if not user.id:
raise ValueError("User ID is required")
# 获取频道类型以判断是否需要存储到数据库
async with with_db() as session:
from app.database.chat import ChatChannel, ChannelType
from sqlmodel import select
channel_result = await session.exec(
select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
)
channel_type = channel_result.first()
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
# 准备消息数据
message_data = {
"message_id": message_id,
@@ -76,6 +87,7 @@ class RedisMessageSystem:
"uuid": user_uuid or "",
"status": "cached", # Redis 缓存状态
"created_at": time.time(),
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
}
# 立即存储到 Redis
@@ -118,9 +130,14 @@ class RedisMessageSystem:
uuid=user_uuid,
)
logger.info(
f"Message {message_id} sent to Redis cache for channel {channel_id}"
)
if is_multiplayer:
logger.info(
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
)
else:
logger.info(
f"Message {message_id} sent to Redis cache for channel {channel_id}"
)
return response
async def get_messages(
@@ -222,6 +239,9 @@ class RedisMessageSystem:
):
"""存储消息到 Redis"""
try:
# 检查是否是多人房间消息
is_multiplayer = message_data.get("is_multiplayer", False)
# 存储消息数据
await self._redis_exec(
self.redis.hset,
@@ -267,10 +287,14 @@ class RedisMessageSystem:
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
)
# 添加到待持久化队列
await self._redis_exec(
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
)
# 只有非多人房间消息才添加到待持久化队列
if not is_multiplayer:
await self._redis_exec(
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
)
logger.debug(f"Message {message_id} added to persistence queue")
else:
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
@@ -475,6 +499,19 @@ class RedisMessageSystem:
v = v.decode("utf-8")
message_data[k] = v
# 检查是否是多人房间消息,如果是则跳过数据库存储
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
if is_multiplayer:
# 多人房间消息不存储到数据库,直接标记为已跳过
await self._redis_exec(
self.redis.hset,
f"msg:{channel_id}:{message_id}",
"status",
"skipped_multiplayer",
)
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
continue
# 检查消息是否已存在于数据库
existing = await session.get(ChatMessage, int(message_id))
if existing:

View File

@@ -0,0 +1,121 @@
"""
API 状态管理 - 模拟 osu! 的 APIState 和会话管理
"""
from __future__ import annotations
from enum import Enum
from typing import Optional
from datetime import datetime
from pydantic import BaseModel
class APIState(str, Enum):
"""API 连接状态,对应 osu! 的 APIState"""
OFFLINE = "offline"
CONNECTING = "connecting"
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
ONLINE = "online"
FAILING = "failing"
class UserSession(BaseModel):
"""用户会话信息"""
user_id: int
username: str
email: str
session_token: str | None = None
state: APIState = APIState.OFFLINE
requires_verification: bool = False
verification_sent: bool = False
last_verification_attempt: datetime | None = None
failed_attempts: int = 0
ip_address: str | None = None
country_code: str | None = None
is_new_location: bool = False
class SessionManager:
"""会话管理器"""
def __init__(self):
self._sessions: dict[str, UserSession] = {}
def create_session(
self,
user_id: int,
username: str,
email: str,
ip_address: str,
country_code: str | None = None,
is_new_location: bool = False
) -> UserSession:
"""创建新的用户会话"""
import secrets
session_token = secrets.token_urlsafe(32)
# 根据是否为新位置决定初始状态
if is_new_location:
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
else:
state = APIState.ONLINE
session = UserSession(
user_id=user_id,
username=username,
email=email,
session_token=session_token,
state=state,
requires_verification=is_new_location,
ip_address=ip_address,
country_code=country_code,
is_new_location=is_new_location
)
self._sessions[session_token] = session
return session
def get_session(self, session_token: str) -> UserSession | None:
"""获取会话"""
return self._sessions.get(session_token)
def update_session_state(self, session_token: str, state: APIState):
"""更新会话状态"""
if session_token in self._sessions:
self._sessions[session_token].state = state
def mark_verification_sent(self, session_token: str):
"""标记验证邮件已发送"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.verification_sent = True
session.last_verification_attempt = datetime.now()
def increment_failed_attempts(self, session_token: str):
"""增加失败尝试次数"""
if session_token in self._sessions:
self._sessions[session_token].failed_attempts += 1
def verify_session(self, session_token: str) -> bool:
"""验证会话成功"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.state = APIState.ONLINE
session.requires_verification = False
return True
return False
def remove_session(self, session_token: str):
"""移除会话"""
self._sessions.pop(session_token, None)
def cleanup_expired_sessions(self):
"""清理过期会话"""
# 这里可以实现清理逻辑
pass
# 全局会话管理器
session_manager = SessionManager()