添加邮件验证
This commit is contained in:
289
app/service/database_cleanup_service.py
Normal file
289
app/service/database_cleanup_service.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
数据库清理服务 - 清理过期的验证码和会话
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC, timedelta
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
class DatabaseCleanupService:
|
||||
"""数据库清理服务"""
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的邮件验证码
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的验证码记录
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stmt = select(EmailVerification).where(
|
||||
EmailVerification.expires_at < current_time
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
expired_codes = result.all()
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for code in expired_codes:
|
||||
await db.delete(code)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的登录会话
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的登录会话记录
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
expired_sessions = result.all()
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for session in expired_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
|
||||
"""
|
||||
清理旧的已使用验证码记录
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已使用记录,默认7天
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已使用验证码记录
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
|
||||
stmt = select(EmailVerification).where(
|
||||
EmailVerification.is_used == True
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
all_used_codes = result.all()
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_used_codes = [
|
||||
code for code in all_used_codes
|
||||
if code.used_at and code.used_at < cutoff_time
|
||||
]
|
||||
|
||||
# 删除旧的已使用记录
|
||||
deleted_count = 0
|
||||
for code in old_used_codes:
|
||||
await db.delete(code)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||
"""
|
||||
清理旧的已验证会话记录
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已验证记录,默认30天
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已验证会话记录
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == True
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
all_verified_sessions = result.all()
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_verified_sessions = [
|
||||
session for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_time
|
||||
]
|
||||
|
||||
# 删除旧的已验证记录
|
||||
deleted_count = 0
|
||||
for session in old_verified_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
|
||||
"""
|
||||
运行完整的清理流程
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
dict: 各项清理的结果统计
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# 清理过期的验证码
|
||||
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||
|
||||
# 清理过期的登录会话
|
||||
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
# 清理7天前的已使用验证码
|
||||
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||
|
||||
# 清理30天前的已验证会话
|
||||
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
|
||||
|
||||
total_cleaned = sum(results.values())
|
||||
if total_cleaned > 0:
|
||||
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
|
||||
"""
|
||||
获取清理统计信息
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
dict: 统计信息
|
||||
"""
|
||||
try:
|
||||
current_time = datetime.now(UTC)
|
||||
cutoff_7_days = current_time - timedelta(days=7)
|
||||
cutoff_30_days = current_time - timedelta(days=30)
|
||||
|
||||
# 统计过期的验证码数量
|
||||
expired_codes_stmt = select(EmailVerification).where(
|
||||
EmailVerification.expires_at < current_time
|
||||
)
|
||||
expired_codes_result = await db.exec(expired_codes_stmt)
|
||||
expired_codes_count = len(expired_codes_result.all())
|
||||
|
||||
# 统计过期的登录会话数量
|
||||
expired_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time
|
||||
)
|
||||
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||
expired_sessions_count = len(expired_sessions_result.all())
|
||||
|
||||
# 统计7天前的已使用验证码数量
|
||||
old_used_codes_stmt = select(EmailVerification).where(
|
||||
EmailVerification.is_used == True
|
||||
)
|
||||
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
||||
all_used_codes = old_used_codes_result.all()
|
||||
old_used_codes_count = len([
|
||||
code for code in all_used_codes
|
||||
if code.used_at and code.used_at < cutoff_7_days
|
||||
])
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == True
|
||||
)
|
||||
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||
all_verified_sessions = old_verified_sessions_result.all()
|
||||
old_verified_sessions_count = len([
|
||||
session for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_30_days
|
||||
])
|
||||
|
||||
return {
|
||||
"expired_verification_codes": expired_codes_count,
|
||||
"expired_login_sessions": expired_sessions_count,
|
||||
"old_used_verification_codes": old_used_codes_count,
|
||||
"old_verified_sessions": old_verified_sessions_count,
|
||||
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
|
||||
return {
|
||||
"expired_verification_codes": 0,
|
||||
"expired_login_sessions": 0,
|
||||
"old_used_verification_codes": 0,
|
||||
"old_verified_sessions": 0,
|
||||
"total_cleanable": 0
|
||||
}
|
||||
167
app/service/email_service.py
Normal file
167
app/service/email_service.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
邮件验证服务
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""邮件发送服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
|
||||
self.smtp_port = getattr(settings, 'smtp_port', 587)
|
||||
self.smtp_username = getattr(settings, 'smtp_username', '')
|
||||
self.smtp_password = getattr(settings, 'smtp_password', '')
|
||||
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
|
||||
self.from_name = getattr(settings, 'from_name', 'osu! server')
|
||||
|
||||
def generate_verification_code(self) -> str:
|
||||
"""生成8位验证码"""
|
||||
# 只使用数字,避免混淆
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = f"{self.from_name} <{self.from_email}>"
|
||||
msg['To'] = email
|
||||
msg['Subject'] = "邮箱验证 - Email Verification"
|
||||
|
||||
# HTML 邮件内容
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
.container {{
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
font-family: Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
}}
|
||||
.header {{
|
||||
background: linear-gradient(135deg, #ff66aa, #ff9966);
|
||||
color: white;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
border-radius: 10px 10px 0 0;
|
||||
}}
|
||||
.content {{
|
||||
background: #f9f9f9;
|
||||
padding: 30px;
|
||||
border: 1px solid #ddd;
|
||||
}}
|
||||
.code {{
|
||||
background: #fff;
|
||||
border: 2px solid #ff66aa;
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
text-align: center;
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
letter-spacing: 3px;
|
||||
margin: 20px 0;
|
||||
color: #333;
|
||||
}}
|
||||
.footer {{
|
||||
background: #333;
|
||||
color: #fff;
|
||||
padding: 15px;
|
||||
text-align: center;
|
||||
border-radius: 0 0 10px 10px;
|
||||
font-size: 12px;
|
||||
}}
|
||||
.warning {{
|
||||
background: #fff3cd;
|
||||
border: 1px solid #ffeaa7;
|
||||
border-radius: 5px;
|
||||
padding: 10px;
|
||||
margin: 15px 0;
|
||||
color: #856404;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1> osu! 邮箱验证</h1>
|
||||
<p>Email Verification</p>
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
<h2>你好 {username}!</h2>
|
||||
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
|
||||
|
||||
<div class="code">{code}</div>
|
||||
|
||||
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
|
||||
|
||||
<div class="warning">
|
||||
<strong>注意:</strong>
|
||||
<ul>
|
||||
<li>请不要与任何人分享这个验证码</li>
|
||||
<li>如果你没有请求此验证码,请忽略这封邮件</li>
|
||||
<li>验证码只能使用一次</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<p>如果你有任何问题,请联系我们的支持团队。</p>
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||
|
||||
<h3>Hello {username}!</h3>
|
||||
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
|
||||
|
||||
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
|
||||
|
||||
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>This email was sent automatically, please do not reply.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
|
||||
|
||||
# 发送邮件
|
||||
if not settings.enable_email_sending:
|
||||
# 邮件发送功能禁用时只记录日志,不实际发送
|
||||
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
|
||||
return True
|
||||
|
||||
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
|
||||
if self.smtp_username and self.smtp_password:
|
||||
server.starttls()
|
||||
server.login(self.smtp_username, self.smtp_password)
|
||||
|
||||
server.send_message(msg)
|
||||
|
||||
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Failed to send email: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局邮件服务实例
|
||||
email_service = EmailService()
|
||||
367
app/service/email_verification_service.py
Normal file
367
app/service/email_verification_service.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
邮件验证管理服务
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_service import email_service
|
||||
from app.log import logger
|
||||
from app.config import settings
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel import select
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class EmailVerificationService:
|
||||
"""邮件验证服务"""
|
||||
|
||||
@staticmethod
|
||||
def generate_verification_code() -> str:
|
||||
"""生成8位验证码"""
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
@staticmethod
|
||||
def generate_session_token() -> str:
|
||||
"""生成会话令牌"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
async def create_verification_record(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
) -> tuple[EmailVerification, str]:
|
||||
"""创建邮件验证记录"""
|
||||
|
||||
# 检查是否有未过期的验证码
|
||||
existing_result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
existing = existing_result.first()
|
||||
|
||||
if existing:
|
||||
# 如果有未过期的验证码,直接返回
|
||||
return existing, existing.verification_code
|
||||
|
||||
# 生成新的验证码
|
||||
code = EmailVerificationService.generate_verification_code()
|
||||
|
||||
# 创建验证记录
|
||||
verification = EmailVerification(
|
||||
user_id=user_id,
|
||||
email=email,
|
||||
verification_code=code,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
db.add(verification)
|
||||
await db.commit()
|
||||
await db.refresh(verification)
|
||||
|
||||
# 存储到 Redis(用于快速验证)
|
||||
await redis.setex(
|
||||
f"email_verification:{user_id}:{code}",
|
||||
600, # 10分钟过期
|
||||
str(verification.id) if verification.id else "0"
|
||||
)
|
||||
|
||||
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
|
||||
return verification, code
|
||||
|
||||
@staticmethod
|
||||
async def send_verification_email(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
|
||||
return True # 返回成功,但不执行验证流程
|
||||
|
||||
# 创建验证记录
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
# 发送邮件
|
||||
success = await email_service.send_verification_email(email, code, username)
|
||||
|
||||
if success:
|
||||
logger.info(f"[Email Verification] Successfully sent verification email to {email} (user: {username})")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[Email Verification] Failed to send verification email: {email} (user: {username})")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def verify_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
code: str,
|
||||
ip_address: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
|
||||
# 仍然标记登录会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
return True, "验证成功(邮件验证功能已禁用)"
|
||||
|
||||
# 先从 Redis 检查
|
||||
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
|
||||
if not verification_id:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
# 从数据库获取验证记录
|
||||
result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.id == int(verification_id),
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.verification_code == code,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
|
||||
verification = result.first()
|
||||
if not verification:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
# 标记为已使用
|
||||
verification.is_used = True
|
||||
verification.used_at = datetime.now(UTC)
|
||||
|
||||
# 同时更新对应的登录会话状态
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 删除 Redis 记录
|
||||
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||
|
||||
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||
return True, "验证成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
@staticmethod
|
||||
async def resend_verification_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
) -> tuple[bool, str]:
|
||||
"""重新发送验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
|
||||
return True, "验证码已发送(邮件验证功能已禁用)"
|
||||
|
||||
# 检查重发频率限制(60秒内只能发送一次)
|
||||
rate_limit_key = f"email_verification_rate_limit:{user_id}"
|
||||
if await redis.get(rate_limit_key):
|
||||
return False, "请等待60秒后再重新发送"
|
||||
|
||||
# 设置频率限制
|
||||
await redis.setex(rate_limit_key, 60, "1")
|
||||
|
||||
# 生成新的验证码
|
||||
success = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, username, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
if success:
|
||||
return True, "验证码已重新发送"
|
||||
else:
|
||||
return False, "重新发送失败,请稍后再试"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
|
||||
return False, "重新发送过程中发生错误"
|
||||
|
||||
|
||||
class LoginSessionService:
|
||||
"""登录会话服务"""
|
||||
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
session_token = EmailVerificationService.generate_session_token()
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
|
||||
is_verified=not is_new_location # 新位置需要验证
|
||||
)
|
||||
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
# 存储到 Redis
|
||||
await redis.setex(
|
||||
f"login_session:{session_token}",
|
||||
86400, # 24小时
|
||||
user_id
|
||||
)
|
||||
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
async def verify_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
session_token: str,
|
||||
verification_code: str
|
||||
) -> tuple[bool, str]:
|
||||
"""验证会话(通过邮件验证码)"""
|
||||
try:
|
||||
# 从 Redis 获取用户ID
|
||||
user_id = await redis.get(f"login_session:{session_token}")
|
||||
if not user_id:
|
||||
return False, "会话无效或已过期"
|
||||
|
||||
user_id = int(user_id)
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_code
|
||||
)
|
||||
|
||||
if not success:
|
||||
return False, message
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_token,
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False
|
||||
)
|
||||
)
|
||||
|
||||
session = result.first()
|
||||
if session:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||
return True, "会话验证成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
country_code: str | None = None
|
||||
) -> bool:
|
||||
"""检查是否为新位置登录"""
|
||||
try:
|
||||
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.created_at > thirty_days_ago,
|
||||
(LoginSession.ip_address == ip_address) |
|
||||
(LoginSession.country_code == country_code)
|
||||
)
|
||||
)
|
||||
|
||||
existing_sessions = result.all()
|
||||
|
||||
# 如果有历史记录,则不是新位置
|
||||
return len(existing_sessions) == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||
# 出错时默认为新位置(更安全)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(
|
||||
db: AsyncSession,
|
||||
user_id: int
|
||||
) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.expires_at > datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
|
||||
sessions = result.all()
|
||||
|
||||
# 标记所有会话为已验证
|
||||
for session in sessions:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||
return False
|
||||
@@ -65,6 +65,17 @@ class RedisMessageSystem:
|
||||
if not user.id:
|
||||
raise ValueError("User ID is required")
|
||||
|
||||
# 获取频道类型以判断是否需要存储到数据库
|
||||
async with with_db() as session:
|
||||
from app.database.chat import ChatChannel, ChannelType
|
||||
from sqlmodel import select
|
||||
|
||||
channel_result = await session.exec(
|
||||
select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
channel_type = channel_result.first()
|
||||
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
|
||||
|
||||
# 准备消息数据
|
||||
message_data = {
|
||||
"message_id": message_id,
|
||||
@@ -76,6 +87,7 @@ class RedisMessageSystem:
|
||||
"uuid": user_uuid or "",
|
||||
"status": "cached", # Redis 缓存状态
|
||||
"created_at": time.time(),
|
||||
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
|
||||
}
|
||||
|
||||
# 立即存储到 Redis
|
||||
@@ -118,9 +130,14 @@ class RedisMessageSystem:
|
||||
uuid=user_uuid,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
||||
)
|
||||
if is_multiplayer:
|
||||
logger.info(
|
||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
||||
)
|
||||
return response
|
||||
|
||||
async def get_messages(
|
||||
@@ -222,6 +239,9 @@ class RedisMessageSystem:
|
||||
):
|
||||
"""存储消息到 Redis"""
|
||||
try:
|
||||
# 检查是否是多人房间消息
|
||||
is_multiplayer = message_data.get("is_multiplayer", False)
|
||||
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
@@ -267,10 +287,14 @@ class RedisMessageSystem:
|
||||
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
|
||||
)
|
||||
|
||||
# 添加到待持久化队列
|
||||
await self._redis_exec(
|
||||
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
||||
)
|
||||
# 只有非多人房间消息才添加到待持久化队列
|
||||
if not is_multiplayer:
|
||||
await self._redis_exec(
|
||||
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
||||
)
|
||||
logger.debug(f"Message {message_id} added to persistence queue")
|
||||
else:
|
||||
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store message to Redis: {e}")
|
||||
@@ -475,6 +499,19 @@ class RedisMessageSystem:
|
||||
v = v.decode("utf-8")
|
||||
message_data[k] = v
|
||||
|
||||
# 检查是否是多人房间消息,如果是则跳过数据库存储
|
||||
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
|
||||
if is_multiplayer:
|
||||
# 多人房间消息不存储到数据库,直接标记为已跳过
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
"status",
|
||||
"skipped_multiplayer",
|
||||
)
|
||||
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
|
||||
continue
|
||||
|
||||
# 检查消息是否已存在于数据库
|
||||
existing = await session.get(ChatMessage, int(message_id))
|
||||
if existing:
|
||||
|
||||
121
app/service/session_manager.py
Normal file
121
app/service/session_manager.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
API 状态管理 - 模拟 osu! 的 APIState 和会话管理
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIState(str, Enum):
|
||||
"""API 连接状态,对应 osu! 的 APIState"""
|
||||
OFFLINE = "offline"
|
||||
CONNECTING = "connecting"
|
||||
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
|
||||
ONLINE = "online"
|
||||
FAILING = "failing"
|
||||
|
||||
|
||||
class UserSession(BaseModel):
|
||||
"""用户会话信息"""
|
||||
user_id: int
|
||||
username: str
|
||||
email: str
|
||||
session_token: str | None = None
|
||||
state: APIState = APIState.OFFLINE
|
||||
requires_verification: bool = False
|
||||
verification_sent: bool = False
|
||||
last_verification_attempt: datetime | None = None
|
||||
failed_attempts: int = 0
|
||||
ip_address: str | None = None
|
||||
country_code: str | None = None
|
||||
is_new_location: bool = False
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, UserSession] = {}
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False
|
||||
) -> UserSession:
|
||||
"""创建新的用户会话"""
|
||||
import secrets
|
||||
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
|
||||
# 根据是否为新位置决定初始状态
|
||||
if is_new_location:
|
||||
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
|
||||
else:
|
||||
state = APIState.ONLINE
|
||||
|
||||
session = UserSession(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
email=email,
|
||||
session_token=session_token,
|
||||
state=state,
|
||||
requires_verification=is_new_location,
|
||||
ip_address=ip_address,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location
|
||||
)
|
||||
|
||||
self._sessions[session_token] = session
|
||||
return session
|
||||
|
||||
def get_session(self, session_token: str) -> UserSession | None:
|
||||
"""获取会话"""
|
||||
return self._sessions.get(session_token)
|
||||
|
||||
def update_session_state(self, session_token: str, state: APIState):
|
||||
"""更新会话状态"""
|
||||
if session_token in self._sessions:
|
||||
self._sessions[session_token].state = state
|
||||
|
||||
def mark_verification_sent(self, session_token: str):
|
||||
"""标记验证邮件已发送"""
|
||||
if session_token in self._sessions:
|
||||
session = self._sessions[session_token]
|
||||
session.verification_sent = True
|
||||
session.last_verification_attempt = datetime.now()
|
||||
|
||||
def increment_failed_attempts(self, session_token: str):
|
||||
"""增加失败尝试次数"""
|
||||
if session_token in self._sessions:
|
||||
self._sessions[session_token].failed_attempts += 1
|
||||
|
||||
def verify_session(self, session_token: str) -> bool:
|
||||
"""验证会话成功"""
|
||||
if session_token in self._sessions:
|
||||
session = self._sessions[session_token]
|
||||
session.state = APIState.ONLINE
|
||||
session.requires_verification = False
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_session(self, session_token: str):
|
||||
"""移除会话"""
|
||||
self._sessions.pop(session_token, None)
|
||||
|
||||
def cleanup_expired_sessions(self):
|
||||
"""清理过期会话"""
|
||||
# 这里可以实现清理逻辑
|
||||
pass
|
||||
|
||||
|
||||
# 全局会话管理器
|
||||
session_manager = SessionManager()
|
||||
Reference in New Issue
Block a user