diff --git a/.env.example b/.env.example
index 20ad3cd..56febcc 100644
--- a/.env.example
+++ b/.env.example
@@ -45,6 +45,18 @@ FETCHER_SCOPES=public
# 日志设置
LOG_LEVEL="INFO"
+# 邮件服务设置
+SMTP_SERVER="smtp.gmail.com" # SMTP 服务器地址
+SMTP_PORT=587 # SMTP 端口
+SMTP_USERNAME="your-email@gmail.com" # 邮箱用户名
+SMTP_PASSWORD="your-app-password" # 邮箱密码或应用专用密码
+FROM_EMAIL="noreply@your-server.com" # 发送方邮箱
+FROM_NAME="osu! Private Server" # 发送方名称
+
+# 邮件验证功能开关
+ENABLE_EMAIL_VERIFICATION=true # 是否启用邮件验证功能(新位置登录时需要邮件验证)
+ENABLE_EMAIL_SENDING=false # 是否真实发送邮件(false时仅模拟发送,输出到日志)
+
# Sentry 设置,为空表示不启用
SENTRY_DSN
diff --git a/app/config.py b/app/config.py
index 26d2bec..42140b7 100644
--- a/app/config.py
+++ b/app/config.py
@@ -117,6 +117,22 @@ class Settings(BaseSettings):
# 日志设置
log_level: str = "INFO"
+ # 邮件服务设置
+ smtp_server: str = "localhost"
+ smtp_port: int = 587
+ smtp_username: str = ""
+ smtp_password: str = ""
+ from_email: str = "noreply@example.com"
+ from_name: str = "osu! server"
+
+ # 邮件验证功能开关
+ enable_email_verification: bool = Field(
+ default=True, description="是否启用邮件验证功能"
+ )
+ enable_email_sending: bool = Field(
+ default=False, description="是否真实发送邮件(False时仅模拟发送)"
+ )
+
# Sentry 配置
sentry_dsn: HttpUrl | None = None
diff --git a/app/database/__init__.py b/app/database/__init__.py
index 6ff1d21..1fd0f00 100644
--- a/app/database/__init__.py
+++ b/app/database/__init__.py
@@ -23,6 +23,7 @@ from .counts import (
ReplayWatchedCount,
)
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
+from .email_verification import EmailVerification, LoginSession
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
User,
diff --git a/app/database/email_verification.py b/app/database/email_verification.py
new file mode 100644
index 0000000..baf053b
--- /dev/null
+++ b/app/database/email_verification.py
@@ -0,0 +1,44 @@
+"""
+邮件验证相关数据库模型
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, UTC
+from sqlmodel import SQLModel, Field
+from sqlalchemy import Column, BigInteger, ForeignKey
+
+
+class EmailVerification(SQLModel, table=True):
+ """邮件验证记录"""
+
+ __tablename__: str = "email_verifications"
+
+ id: int | None = Field(default=None, primary_key=True)
+ user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
+ email: str = Field(index=True)
+ verification_code: str = Field(max_length=8) # 8位验证码
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+ expires_at: datetime = Field() # 验证码过期时间
+ is_used: bool = Field(default=False) # 是否已使用
+ used_at: datetime | None = Field(default=None)
+ ip_address: str | None = Field(default=None) # 请求IP
+ user_agent: str | None = Field(default=None) # 用户代理
+
+
+class LoginSession(SQLModel, table=True):
+ """登录会话记录"""
+
+ __tablename__: str = "login_sessions"
+
+ id: int | None = Field(default=None, primary_key=True)
+ user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
+ session_token: str = Field(unique=True, index=True) # 会话令牌
+ ip_address: str = Field() # 登录IP
+ user_agent: str | None = Field(default=None)
+ country_code: str | None = Field(default=None)
+ is_verified: bool = Field(default=False) # 是否已验证
+ created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+ verified_at: datetime | None = Field(default=None)
+ expires_at: datetime = Field() # 会话过期时间
+ is_new_location: bool = Field(default=False) # 是否新位置登录
diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py
index 232b00d..e215dab 100644
--- a/app/database/lazer_user.py
+++ b/app/database/lazer_user.py
@@ -475,6 +475,25 @@ class UserResp(UserBase):
)
).one()
+ # 检查会话验证状态
+ # 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
+ from app.config import settings
+ if not settings.enable_email_verification:
+ u.session_verified = True
+ else:
+ # 如果用户有未验证的登录会话,则设置 session_verified 为 false
+ from .email_verification import LoginSession
+ unverified_session = (
+ await session.exec(
+ select(LoginSession).where(
+ LoginSession.user_id == obj.id,
+ LoginSession.is_verified == False,
+ LoginSession.expires_at > datetime.now(UTC)
+ )
+ )
+ ).first()
+ u.session_verified = unverified_session is None
+
return u
diff --git a/app/models/api_me.py b/app/models/api_me.py
new file mode 100644
index 0000000..8e632e5
--- /dev/null
+++ b/app/models/api_me.py
@@ -0,0 +1,17 @@
+"""
+APIMe 响应模型 - 对应 osu! 的 APIMe 类型
+"""
+
+from __future__ import annotations
+
+from app.database.lazer_user import UserResp
+
+
+class APIMe(UserResp):
+ """
+ /me 端点的响应模型
+ 对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
+
+ session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
+ """
+ pass
diff --git a/app/models/extended_auth.py b/app/models/extended_auth.py
new file mode 100644
index 0000000..e98ea7d
--- /dev/null
+++ b/app/models/extended_auth.py
@@ -0,0 +1,31 @@
+"""
+扩展的 OAuth 响应模型,支持二次验证
+"""
+
+from __future__ import annotations
+
+from pydantic import BaseModel
+
+
+class ExtendedTokenResponse(BaseModel):
+ """扩展的令牌响应,支持二次验证状态"""
+ access_token: str | None = None
+ token_type: str = "Bearer"
+ expires_in: int | None = None
+ refresh_token: str | None = None
+ scope: str | None = None
+
+ # 二次验证相关字段
+ requires_second_factor: bool = False
+ verification_message: str | None = None
+ user_id: int | None = None # 用于二次验证的用户ID
+
+
+class SessionState(BaseModel):
+ """会话状态"""
+ user_id: int
+ username: str
+ email: str
+ requires_verification: bool
+ session_token: str | None = None
+ verification_sent: bool = False
diff --git a/app/router/auth.py b/app/router/auth.py
index 4120c46..a7717e3 100644
--- a/app/router/auth.py
+++ b/app/router/auth.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import re
-from typing import Literal
+from typing import Literal, Union
from app.auth import (
authenticate_user,
@@ -28,8 +28,13 @@ from app.models.oauth import (
TokenResponse,
UserRegistrationErrors,
)
+from app.models.extended_auth import ExtendedTokenResponse
from app.models.score import GameMode
from app.service.login_log_service import LoginLogService
+from app.service.email_verification_service import (
+ EmailVerificationService,
+ LoginSessionService
+)
from fastapi import APIRouter, Depends, Form, Request
from fastapi.responses import JSONResponse
@@ -198,7 +203,7 @@ async def register_user(
@router.post(
"/oauth/token",
- response_model=TokenResponse,
+ response_model=Union[TokenResponse, ExtendedTokenResponse],
name="获取访问令牌",
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
)
@@ -218,6 +223,7 @@ async def oauth_token(
None, description="刷新令牌(仅刷新令牌模式需要)"
),
redis: Redis = Depends(get_redis),
+ geoip: GeoIPHelper = Depends(get_geoip_helper),
):
scopes = scope.split(" ")
@@ -295,17 +301,68 @@ async def oauth_token(
# 确保用户对象与当前会话关联
await db.refresh(user)
- # 记录成功的登录
+ # 获取用户信息和客户端信息
user_id = getattr(user, "id")
assert user_id is not None, "User ID should not be None after authentication"
- await LoginLogService.record_login(
- db=db,
- user_id=user_id,
- request=request,
- login_success=True,
- login_method="password",
- notes=f"OAuth password grant for client {client_id}",
+
+ from app.dependencies.geoip import get_client_ip
+ ip_address = get_client_ip(request)
+ user_agent = request.headers.get("User-Agent", "")
+
+ # 获取国家代码
+ geo_info = geoip.lookup(ip_address)
+ country_code = geo_info.get("country_iso", "XX")
+
+ # 检查是否为新位置登录
+ is_new_location = await LoginSessionService.check_new_location(
+ db, user_id, ip_address, country_code
)
+
+ # 创建登录会话记录
+ login_session = await LoginSessionService.create_session(
+ db, redis, user_id, ip_address, user_agent, country_code, is_new_location
+ )
+
+ # 如果是新位置登录,需要邮件验证
+ if is_new_location and settings.enable_email_verification:
+ # 刷新用户对象以确保属性已加载
+ await db.refresh(user)
+
+ # 发送邮件验证码
+ verification_sent = await EmailVerificationService.send_verification_email(
+ db, redis, user_id, user.username, user.email, ip_address, user_agent
+ )
+
+ # 记录需要二次验证的登录尝试
+ await LoginLogService.record_login(
+ db=db,
+ user_id=user_id,
+ request=request,
+ login_success=True,
+ login_method="password_pending_verification",
+ notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
+ )
+
+ if not verification_sent:
+ # 邮件发送失败,记录错误
+ logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
+ elif is_new_location and not settings.enable_email_verification:
+ # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
+ await LoginSessionService.mark_session_verified(db, user_id)
+ logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
+ else:
+ # 不是新位置登录,正常登录
+ await LoginLogService.record_login(
+ db=db,
+ user_id=user_id,
+ request=request,
+ login_success=True,
+ login_method="password",
+ notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
+ )
+
+ # 无论是否新位置登录,都返回正常的token
+ # session_verified状态通过/me接口的session_verified字段来体现
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
diff --git a/app/router/v2/__init__.py b/app/router/v2/__init__.py
index 1cd8d16..8e59b4e 100644
--- a/app/router/v2/__init__.py
+++ b/app/router/v2/__init__.py
@@ -9,6 +9,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
relationship,
room,
score,
+ session_verify,
user,
)
from .router import router as api_v2_router
diff --git a/app/router/v2/me.py b/app/router/v2/me.py
index 1ca9097..e7b80ae 100644
--- a/app/router/v2/me.py
+++ b/app/router/v2/me.py
@@ -1,10 +1,11 @@
from __future__ import annotations
-from app.database import User, UserResp
+from app.database import User
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.models.score import GameMode
+from app.models.api_me import APIMe
from .router import router
@@ -13,7 +14,7 @@ from fastapi import Path, Security
@router.get(
"/me/{ruleset}",
- response_model=UserResp,
+ response_model=APIMe,
name="获取当前用户信息 (指定 ruleset)",
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
tags=["用户"],
@@ -23,17 +24,18 @@ async def get_user_info_with_ruleset(
ruleset: GameMode = Path(description="指定 ruleset"),
current_user: User = Security(get_current_user, scopes=["identify"]),
):
- return await UserResp.from_db(
+ user_resp = await APIMe.from_db(
current_user,
session,
ALL_INCLUDED,
ruleset,
)
+ return user_resp
@router.get(
"/me/",
- response_model=UserResp,
+ response_model=APIMe,
name="获取当前用户信息",
description="获取当前登录用户信息。",
tags=["用户"],
@@ -42,9 +44,10 @@ async def get_user_info_default(
session: Database,
current_user: User = Security(get_current_user, scopes=["identify"]),
):
- return await UserResp.from_db(
+ user_resp = await APIMe.from_db(
current_user,
session,
ALL_INCLUDED,
None,
)
+ return user_resp
diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py
new file mode 100644
index 0000000..5203785
--- /dev/null
+++ b/app/router/v2/session_verify.py
@@ -0,0 +1,204 @@
+"""
+会话验证路由 - 实现类似 osu! 的邮件验证流程 (API v2)
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, UTC
+from typing import Annotated
+
+from app.auth import authenticate_user
+from app.config import settings
+from app.database import User
+from app.dependencies import get_current_user
+from app.dependencies.database import Database, get_redis
+from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
+from app.database.email_verification import EmailVerification, LoginSession
+from app.service.email_verification_service import (
+ EmailVerificationService,
+ LoginSessionService
+)
+from app.service.login_log_service import LoginLogService
+from app.models.extended_auth import ExtendedTokenResponse
+
+from fastapi import Form, Depends, Request, HTTPException, status, Security
+from fastapi.responses import JSONResponse, Response
+from pydantic import BaseModel
+from redis.asyncio import Redis
+from sqlmodel import select
+
+from .router import router
+
+
+class SessionReissueResponse(BaseModel):
+ """重新发送验证码响应"""
+ success: bool
+ message: str
+
+
+@router.post(
+ "/session/verify",
+ name="验证会话",
+ description="验证邮件验证码并完成会话认证",
+ status_code=204
+)
+async def verify_session(
+ request: Request,
+ db: Database,
+ redis: Annotated[Redis, Depends(get_redis)],
+ verification_key: str = Form(..., description="8位邮件验证码"),
+ current_user: User = Security(get_current_user)
+) -> Response:
+ """
+ 验证邮件验证码并完成会话认证
+
+ 对应 osu! 的 session/verify 接口
+ 成功时返回 204 No Content,失败时返回 401 Unauthorized
+ """
+ try:
+ from app.dependencies.geoip import get_client_ip
+ ip_address = get_client_ip(request)
+ user_agent = request.headers.get("User-Agent", "Unknown")
+
+ # 从当前认证用户获取信息
+ user_id = current_user.id
+ if not user_id:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="用户未认证"
+ )
+
+ # 验证邮件验证码
+ success, message = await EmailVerificationService.verify_code(
+ db, redis, user_id, verification_key
+ )
+
+ if success:
+ # 记录成功的邮件验证
+ await LoginLogService.record_login(
+ db=db,
+ user_id=user_id,
+ request=request,
+ login_method="email_verification",
+ login_success=True,
+ notes=f"邮件验证成功"
+ )
+
+ # 返回 204 No Content 表示验证成功
+ return Response(status_code=status.HTTP_204_NO_CONTENT)
+ else:
+ # 记录失败的邮件验证尝试
+ await LoginLogService.record_failed_login(
+ db=db,
+ request=request,
+ attempted_username=current_user.username,
+ login_method="email_verification",
+ notes=f"邮件验证失败: {message}"
+ )
+
+ # 返回 401 Unauthorized 表示验证失败
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=message
+ )
+
+ except ValueError:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="无效的用户会话"
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="验证过程中发生错误"
+ )
+
+
+@router.post(
+ "/session/verify/reissue",
+ name="重新发送验证码",
+ description="重新发送邮件验证码",
+ response_model=SessionReissueResponse
+)
+async def reissue_verification_code(
+ request: Request,
+ db: Database,
+ redis: Annotated[Redis, Depends(get_redis)],
+ current_user: User = Security(get_current_user)
+) -> SessionReissueResponse:
+ """
+ 重新发送邮件验证码
+
+ 对应 osu! 的 session/verify/reissue 接口
+ """
+ try:
+ from app.dependencies.geoip import get_client_ip
+ ip_address = get_client_ip(request)
+ user_agent = request.headers.get("User-Agent", "Unknown")
+
+ # 从当前认证用户获取信息
+ user_id = current_user.id
+ if not user_id:
+ return SessionReissueResponse(
+ success=False,
+ message="用户未认证"
+ )
+
+ # 重新发送验证码
+ success, message = await EmailVerificationService.resend_verification_code(
+ db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
+ )
+
+ return SessionReissueResponse(
+ success=success,
+ message=message
+ )
+
+ except ValueError:
+ return SessionReissueResponse(
+ success=False,
+ message="无效的用户会话"
+ )
+ except Exception as e:
+ return SessionReissueResponse(
+ success=False,
+ message="重新发送过程中发生错误"
+ )
+
+
+@router.post(
+ "/session/check-new-location",
+ name="检查新位置登录",
+ description="检查登录是否来自新位置(内部接口)"
+)
+async def check_new_location(
+ request: Request,
+ db: Database,
+ user_id: int,
+ geoip: GeoIPHelper = Depends(get_geoip_helper),
+):
+ """
+ 检查是否为新位置登录
+ 这是一个内部接口,用于登录流程中判断是否需要邮件验证
+ """
+ try:
+ from app.dependencies.geoip import get_client_ip
+ ip_address = get_client_ip(request)
+ geo_info = geoip.lookup(ip_address)
+ country_code = geo_info.get("country_iso", "XX")
+
+ is_new_location = await LoginSessionService.check_new_location(
+ db, user_id, ip_address, country_code
+ )
+
+ return {
+ "is_new_location": is_new_location,
+ "ip_address": ip_address,
+ "country_code": country_code
+ }
+
+ except Exception as e:
+ return {
+ "is_new_location": True, # 出错时默认为新位置
+ "error": str(e)
+ }
diff --git a/app/scheduler/database_cleanup_scheduler.py b/app/scheduler/database_cleanup_scheduler.py
new file mode 100644
index 0000000..4dfb21a
--- /dev/null
+++ b/app/scheduler/database_cleanup_scheduler.py
@@ -0,0 +1,118 @@
+"""
+数据库清理调度器 - 定时清理过期数据
+"""
+
+from __future__ import annotations
+
+import asyncio
+from datetime import datetime
+
+from app.config import settings
+from app.dependencies.database import engine
+from app.log import logger
+from app.service.database_cleanup_service import DatabaseCleanupService
+
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+
+class DatabaseCleanupScheduler:
+ """数据库清理调度器"""
+
+ def __init__(self):
+ self.running = False
+ self.task = None
+
+ async def start(self):
+ """启动调度器"""
+ if self.running:
+ return
+
+ self.running = True
+ self.task = asyncio.create_task(self._run_scheduler())
+ logger.debug("Database cleanup scheduler started")
+
+ async def stop(self):
+ """停止调度器"""
+ if not self.running:
+ return
+
+ self.running = False
+ if self.task:
+ self.task.cancel()
+ try:
+ await self.task
+ except asyncio.CancelledError:
+ pass
+ logger.debug("Database cleanup scheduler stopped")
+
+ async def _run_scheduler(self):
+ """运行调度器"""
+ while self.running:
+ try:
+ # 每小时运行一次清理
+ await asyncio.sleep(3600) # 3600秒 = 1小时
+
+ if not self.running:
+ break
+
+ await self._run_cleanup()
+
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Database cleanup scheduler error: {str(e)}")
+ # 发生错误后等待5分钟再继续
+ await asyncio.sleep(300)
+
+ async def _run_cleanup(self):
+ """执行清理任务"""
+ try:
+ async with AsyncSession(engine) as db:
+ logger.debug("Starting scheduled database cleanup...")
+
+ # 清理过期的验证码
+ expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
+
+ # 清理过期的登录会话
+ expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
+
+ # 只在有清理记录时输出总结
+ total_cleaned = expired_codes + expired_sessions
+ if total_cleaned > 0:
+ logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
+
+ except Exception as e:
+ logger.error(f"Error during scheduled database cleanup: {str(e)}")
+
+ async def run_manual_cleanup(self):
+ """手动运行完整清理"""
+ try:
+ async with AsyncSession(engine) as db:
+ logger.debug("Starting manual database cleanup...")
+ results = await DatabaseCleanupService.run_full_cleanup(db)
+ total = sum(results.values())
+ if total > 0:
+ logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
+ return results
+ except Exception as e:
+ logger.error(f"Error during manual database cleanup: {str(e)}")
+ return {}
+
+
+# 全局实例
+database_cleanup_scheduler = DatabaseCleanupScheduler()
+
+
+async def start_database_cleanup_scheduler():
+ """启动数据库清理调度器"""
+ await database_cleanup_scheduler.start()
+
+
+async def stop_database_cleanup_scheduler():
+ """停止数据库清理调度器"""
+ await database_cleanup_scheduler.stop()
+
+
+async def run_manual_database_cleanup():
+ """手动运行数据库清理"""
+ return await database_cleanup_scheduler.run_manual_cleanup()
diff --git a/app/service/database_cleanup_service.py b/app/service/database_cleanup_service.py
new file mode 100644
index 0000000..a4558d8
--- /dev/null
+++ b/app/service/database_cleanup_service.py
@@ -0,0 +1,289 @@
+"""
+数据库清理服务 - 清理过期的验证码和会话
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, UTC, timedelta
+
+from app.database.email_verification import EmailVerification, LoginSession
+from app.log import logger
+
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
+from sqlalchemy import and_
+
+
+class DatabaseCleanupService:
+ """数据库清理服务"""
+
+ @staticmethod
+ async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
+ """
+ 清理过期的邮件验证码
+
+ Args:
+ db: 数据库会话
+
+ Returns:
+ int: 清理的记录数
+ """
+ try:
+ # 查找过期的验证码记录
+ current_time = datetime.now(UTC)
+
+ stmt = select(EmailVerification).where(
+ EmailVerification.expires_at < current_time
+ )
+ result = await db.exec(stmt)
+ expired_codes = result.all()
+
+ # 删除过期的记录
+ deleted_count = 0
+ for code in expired_codes:
+ await db.delete(code)
+ deleted_count += 1
+
+ await db.commit()
+
+ if deleted_count > 0:
+ logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
+
+ return deleted_count
+
+ except Exception as e:
+ await db.rollback()
+ logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
+ return 0
+
+ @staticmethod
+ async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
+ """
+ 清理过期的登录会话
+
+ Args:
+ db: 数据库会话
+
+ Returns:
+ int: 清理的记录数
+ """
+ try:
+ # 查找过期的登录会话记录
+ current_time = datetime.now(UTC)
+
+ stmt = select(LoginSession).where(
+ LoginSession.expires_at < current_time
+ )
+ result = await db.exec(stmt)
+ expired_sessions = result.all()
+
+ # 删除过期的记录
+ deleted_count = 0
+ for session in expired_sessions:
+ await db.delete(session)
+ deleted_count += 1
+
+ await db.commit()
+
+ if deleted_count > 0:
+ logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
+
+ return deleted_count
+
+ except Exception as e:
+ await db.rollback()
+ logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
+ return 0
+
+ @staticmethod
+ async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
+ """
+ 清理旧的已使用验证码记录
+
+ Args:
+ db: 数据库会话
+ days_old: 清理多少天前的已使用记录,默认7天
+
+ Returns:
+ int: 清理的记录数
+ """
+ try:
+ # 查找指定天数前的已使用验证码记录
+ cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
+
+ stmt = select(EmailVerification).where(
+ EmailVerification.is_used == True
+ )
+ result = await db.exec(stmt)
+ all_used_codes = result.all()
+
+ # 筛选出过期的记录
+ old_used_codes = [
+ code for code in all_used_codes
+ if code.used_at and code.used_at < cutoff_time
+ ]
+
+ # 删除旧的已使用记录
+ deleted_count = 0
+ for code in old_used_codes:
+ await db.delete(code)
+ deleted_count += 1
+
+ await db.commit()
+
+ if deleted_count > 0:
+ logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
+
+ return deleted_count
+
+ except Exception as e:
+ await db.rollback()
+ logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
+ return 0
+
+ @staticmethod
+ async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
+ """
+ 清理旧的已验证会话记录
+
+ Args:
+ db: 数据库会话
+ days_old: 清理多少天前的已验证记录,默认30天
+
+ Returns:
+ int: 清理的记录数
+ """
+ try:
+ # 查找指定天数前的已验证会话记录
+ cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
+
+ stmt = select(LoginSession).where(
+ LoginSession.is_verified == True
+ )
+ result = await db.exec(stmt)
+ all_verified_sessions = result.all()
+
+ # 筛选出过期的记录
+ old_verified_sessions = [
+ session for session in all_verified_sessions
+ if session.verified_at and session.verified_at < cutoff_time
+ ]
+
+ # 删除旧的已验证记录
+ deleted_count = 0
+ for session in old_verified_sessions:
+ await db.delete(session)
+ deleted_count += 1
+
+ await db.commit()
+
+ if deleted_count > 0:
+ logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
+
+ return deleted_count
+
+ except Exception as e:
+ await db.rollback()
+ logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
+ return 0
+
+ @staticmethod
+ async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
+ """
+ 运行完整的清理流程
+
+ Args:
+ db: 数据库会话
+
+ Returns:
+ dict: 各项清理的结果统计
+ """
+ results = {}
+
+ # 清理过期的验证码
+ results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
+
+ # 清理过期的登录会话
+ results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
+
+ # 清理7天前的已使用验证码
+ results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
+
+ # 清理30天前的已验证会话
+ results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
+
+ total_cleaned = sum(results.values())
+ if total_cleaned > 0:
+ logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
+
+ return results
+
+ @staticmethod
+ async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
+ """
+ 获取清理统计信息
+
+ Args:
+ db: 数据库会话
+
+ Returns:
+ dict: 统计信息
+ """
+ try:
+ current_time = datetime.now(UTC)
+ cutoff_7_days = current_time - timedelta(days=7)
+ cutoff_30_days = current_time - timedelta(days=30)
+
+ # 统计过期的验证码数量
+ expired_codes_stmt = select(EmailVerification).where(
+ EmailVerification.expires_at < current_time
+ )
+ expired_codes_result = await db.exec(expired_codes_stmt)
+ expired_codes_count = len(expired_codes_result.all())
+
+ # 统计过期的登录会话数量
+ expired_sessions_stmt = select(LoginSession).where(
+ LoginSession.expires_at < current_time
+ )
+ expired_sessions_result = await db.exec(expired_sessions_stmt)
+ expired_sessions_count = len(expired_sessions_result.all())
+
+ # 统计7天前的已使用验证码数量
+ old_used_codes_stmt = select(EmailVerification).where(
+ EmailVerification.is_used == True
+ )
+ old_used_codes_result = await db.exec(old_used_codes_stmt)
+ all_used_codes = old_used_codes_result.all()
+ old_used_codes_count = len([
+ code for code in all_used_codes
+ if code.used_at and code.used_at < cutoff_7_days
+ ])
+
+ # 统计30天前的已验证会话数量
+ old_verified_sessions_stmt = select(LoginSession).where(
+ LoginSession.is_verified == True
+ )
+ old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
+ all_verified_sessions = old_verified_sessions_result.all()
+ old_verified_sessions_count = len([
+ session for session in all_verified_sessions
+ if session.verified_at and session.verified_at < cutoff_30_days
+ ])
+
+ return {
+ "expired_verification_codes": expired_codes_count,
+ "expired_login_sessions": expired_sessions_count,
+ "old_used_verification_codes": old_used_codes_count,
+ "old_verified_sessions": old_verified_sessions_count,
+ "total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
+ }
+
+ except Exception as e:
+ logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
+ return {
+ "expired_verification_codes": 0,
+ "expired_login_sessions": 0,
+ "old_used_verification_codes": 0,
+ "old_verified_sessions": 0,
+ "total_cleanable": 0
+ }
diff --git a/app/service/email_service.py b/app/service/email_service.py
new file mode 100644
index 0000000..3685b15
--- /dev/null
+++ b/app/service/email_service.py
@@ -0,0 +1,167 @@
+"""
+邮件验证服务
+"""
+
+from __future__ import annotations
+
+import smtplib
+from email.mime.text import MIMEText
+from email.mime.multipart import MIMEMultipart
+import secrets
+import string
+from datetime import datetime, UTC, timedelta
+from typing import Optional
+
+from app.config import settings
+from app.log import logger
+
+
+class EmailService:
+ """邮件发送服务"""
+
+ def __init__(self):
+ self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
+ self.smtp_port = getattr(settings, 'smtp_port', 587)
+ self.smtp_username = getattr(settings, 'smtp_username', '')
+ self.smtp_password = getattr(settings, 'smtp_password', '')
+ self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
+ self.from_name = getattr(settings, 'from_name', 'osu! server')
+
+ def generate_verification_code(self) -> str:
+ """生成8位验证码"""
+ # 只使用数字,避免混淆
+ return ''.join(secrets.choice(string.digits) for _ in range(8))
+
+ async def send_verification_email(self, email: str, code: str, username: str) -> bool:
+ """发送验证邮件"""
+ try:
+ msg = MIMEMultipart()
+ msg['From'] = f"{self.from_name} <{self.from_email}>"
+ msg['To'] = email
+ msg['Subject'] = "邮箱验证 - Email Verification"
+
+ # HTML 邮件内容
+ html_content = f"""
+
+
+
+
+
+
+
+
+
+
+
+
你好 {username}!
+
感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:
+
+
{code}
+
+
这个验证码将在 10 分钟后过期。
+
+
+
注意:
+
+ - 请不要与任何人分享这个验证码
+ - 如果你没有请求此验证码,请忽略这封邮件
+ - 验证码只能使用一次
+
+
+
+
如果你有任何问题,请联系我们的支持团队。
+
+
+
+
Hello {username}!
+
Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:
+
+
This verification code will expire in 10 minutes.
+
+
Important: Do not share this verification code with anyone. If you did not request this code, please ignore this email.
+
+
+
+
+
+
+ """
+
+ msg.attach(MIMEText(html_content, 'html', 'utf-8'))
+
+ # 发送邮件
+ if not settings.enable_email_sending:
+ # 邮件发送功能禁用时只记录日志,不实际发送
+ logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
+ return True
+
+ with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
+ if self.smtp_username and self.smtp_password:
+ server.starttls()
+ server.login(self.smtp_username, self.smtp_password)
+
+ server.send_message(msg)
+
+ logger.info(f"[Email Verification] Successfully sent verification code to {email}")
+ return True
+
+ except Exception as e:
+ logger.error(f"[Email Verification] Failed to send email: {e}")
+ return False
+
+
+# 全局邮件服务实例
+email_service = EmailService()
diff --git a/app/service/email_verification_service.py b/app/service/email_verification_service.py
new file mode 100644
index 0000000..d513fbc
--- /dev/null
+++ b/app/service/email_verification_service.py
@@ -0,0 +1,367 @@
+"""
+邮件验证管理服务
+"""
+
+from __future__ import annotations
+
+import secrets
+import string
+from datetime import datetime, UTC, timedelta
+from typing import Optional
+
+from app.database.email_verification import EmailVerification, LoginSession
+from app.service.email_service import email_service
+from app.log import logger
+from app.config import settings
+
+from sqlmodel.ext.asyncio.session import AsyncSession
+from sqlmodel import select
+from redis.asyncio import Redis
+
+
+class EmailVerificationService:
+ """邮件验证服务"""
+
+ @staticmethod
+ def generate_verification_code() -> str:
+ """生成8位验证码"""
+ return ''.join(secrets.choice(string.digits) for _ in range(8))
+
+ @staticmethod
+ def generate_session_token() -> str:
+ """生成会话令牌"""
+ return secrets.token_urlsafe(32)
+
+ @staticmethod
+ async def create_verification_record(
+ db: AsyncSession,
+ redis: Redis,
+ user_id: int,
+ email: str,
+ ip_address: str | None = None,
+ user_agent: str | None = None
+ ) -> tuple[EmailVerification, str]:
+ """创建邮件验证记录"""
+
+ # 检查是否有未过期的验证码
+ existing_result = await db.exec(
+ select(EmailVerification).where(
+ EmailVerification.user_id == user_id,
+ EmailVerification.is_used == False,
+ EmailVerification.expires_at > datetime.now(UTC)
+ )
+ )
+ existing = existing_result.first()
+
+ if existing:
+ # 如果有未过期的验证码,直接返回
+ return existing, existing.verification_code
+
+ # 生成新的验证码
+ code = EmailVerificationService.generate_verification_code()
+
+ # 创建验证记录
+ verification = EmailVerification(
+ user_id=user_id,
+ email=email,
+ verification_code=code,
+ expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
+ ip_address=ip_address,
+ user_agent=user_agent
+ )
+
+ db.add(verification)
+ await db.commit()
+ await db.refresh(verification)
+
+ # 存储到 Redis(用于快速验证)
+ await redis.setex(
+ f"email_verification:{user_id}:{code}",
+ 600, # 10分钟过期
+ str(verification.id) if verification.id else "0"
+ )
+
+ logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
+ return verification, code
+
+ @staticmethod
+ async def send_verification_email(
+ db: AsyncSession,
+ redis: Redis,
+ user_id: int,
+ username: str,
+ email: str,
+ ip_address: str | None = None,
+ user_agent: str | None = None
+ ) -> bool:
+ """发送验证邮件"""
+ try:
+ # 检查是否启用邮件验证功能
+ if not settings.enable_email_verification:
+ logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
+ return True # 返回成功,但不执行验证流程
+
+ # 创建验证记录
+ verification, code = await EmailVerificationService.create_verification_record(
+ db, redis, user_id, email, ip_address, user_agent
+ )
+
+ # 发送邮件
+ success = await email_service.send_verification_email(email, code, username)
+
+ if success:
+ logger.info(f"[Email Verification] Successfully sent verification email to {email} (user: {username})")
+ return True
+ else:
+ logger.error(f"[Email Verification] Failed to send verification email: {email} (user: {username})")
+ return False
+
+ except Exception as e:
+ logger.error(f"[Email Verification] Exception during sending verification email: {e}")
+ return False
+
+ @staticmethod
+ async def verify_code(
+ db: AsyncSession,
+ redis: Redis,
+ user_id: int,
+ code: str,
+ ip_address: str | None = None
+ ) -> tuple[bool, str]:
+ """验证验证码"""
+ try:
+ # 检查是否启用邮件验证功能
+ if not settings.enable_email_verification:
+ logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
+ # 仍然标记登录会话为已验证
+ await LoginSessionService.mark_session_verified(db, user_id)
+ return True, "验证成功(邮件验证功能已禁用)"
+
+ # 先从 Redis 检查
+ verification_id = await redis.get(f"email_verification:{user_id}:{code}")
+ if not verification_id:
+ return False, "验证码无效或已过期"
+
+ # 从数据库获取验证记录
+ result = await db.exec(
+ select(EmailVerification).where(
+ EmailVerification.id == int(verification_id),
+ EmailVerification.user_id == user_id,
+ EmailVerification.verification_code == code,
+ EmailVerification.is_used == False,
+ EmailVerification.expires_at > datetime.now(UTC)
+ )
+ )
+
+ verification = result.first()
+ if not verification:
+ return False, "验证码无效或已过期"
+
+ # 标记为已使用
+ verification.is_used = True
+ verification.used_at = datetime.now(UTC)
+
+ # 同时更新对应的登录会话状态
+ await LoginSessionService.mark_session_verified(db, user_id)
+
+ await db.commit()
+
+ # 删除 Redis 记录
+ await redis.delete(f"email_verification:{user_id}:{code}")
+
+ logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
+ return True, "验证成功"
+
+ except Exception as e:
+ logger.error(f"[Email Verification] Exception during verification code validation: {e}")
+ return False, "验证过程中发生错误"
+
+ @staticmethod
+ async def resend_verification_code(
+ db: AsyncSession,
+ redis: Redis,
+ user_id: int,
+ username: str,
+ email: str,
+ ip_address: str | None = None,
+ user_agent: str | None = None
+ ) -> tuple[bool, str]:
+ """重新发送验证码"""
+ try:
+ # 检查是否启用邮件验证功能
+ if not settings.enable_email_verification:
+ logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
+ return True, "验证码已发送(邮件验证功能已禁用)"
+
+ # 检查重发频率限制(60秒内只能发送一次)
+ rate_limit_key = f"email_verification_rate_limit:{user_id}"
+ if await redis.get(rate_limit_key):
+ return False, "请等待60秒后再重新发送"
+
+ # 设置频率限制
+ await redis.setex(rate_limit_key, 60, "1")
+
+ # 生成新的验证码
+ success = await EmailVerificationService.send_verification_email(
+ db, redis, user_id, username, email, ip_address, user_agent
+ )
+
+ if success:
+ return True, "验证码已重新发送"
+ else:
+ return False, "重新发送失败,请稍后再试"
+
+ except Exception as e:
+ logger.error(f"[Email Verification] Exception during resending verification code: {e}")
+ return False, "重新发送过程中发生错误"
+
+
+class LoginSessionService:
+ """登录会话服务"""
+
+ @staticmethod
+ async def create_session(
+ db: AsyncSession,
+ redis: Redis,
+ user_id: int,
+ ip_address: str,
+ user_agent: str | None = None,
+ country_code: str | None = None,
+ is_new_location: bool = False
+ ) -> LoginSession:
+ """创建登录会话"""
+ session_token = EmailVerificationService.generate_session_token()
+
+ session = LoginSession(
+ user_id=user_id,
+ session_token=session_token,
+ ip_address=ip_address,
+ user_agent=user_agent,
+ country_code=country_code,
+ is_new_location=is_new_location,
+ expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
+ is_verified=not is_new_location # 新位置需要验证
+ )
+
+ db.add(session)
+ await db.commit()
+ await db.refresh(session)
+
+ # 存储到 Redis
+ await redis.setex(
+ f"login_session:{session_token}",
+ 86400, # 24小时
+ user_id
+ )
+
+ logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
+ return session
+
+ @staticmethod
+ async def verify_session(
+ db: AsyncSession,
+ redis: Redis,
+ session_token: str,
+ verification_code: str
+ ) -> tuple[bool, str]:
+ """验证会话(通过邮件验证码)"""
+ try:
+ # 从 Redis 获取用户ID
+ user_id = await redis.get(f"login_session:{session_token}")
+ if not user_id:
+ return False, "会话无效或已过期"
+
+ user_id = int(user_id)
+
+ # 验证邮件验证码
+ success, message = await EmailVerificationService.verify_code(
+ db, redis, user_id, verification_code
+ )
+
+ if not success:
+ return False, message
+
+ # 更新会话状态
+ result = await db.exec(
+ select(LoginSession).where(
+ LoginSession.session_token == session_token,
+ LoginSession.user_id == user_id,
+ LoginSession.is_verified == False
+ )
+ )
+
+ session = result.first()
+ if session:
+ session.is_verified = True
+ session.verified_at = datetime.now(UTC)
+ await db.commit()
+
+ logger.info(f"[Login Session] User {user_id} session verification successful")
+ return True, "会话验证成功"
+
+ except Exception as e:
+ logger.error(f"[Login Session] Exception during session verification: {e}")
+ return False, "验证过程中发生错误"
+
+ @staticmethod
+ async def check_new_location(
+ db: AsyncSession,
+ user_id: int,
+ ip_address: str,
+ country_code: str | None = None
+ ) -> bool:
+ """检查是否为新位置登录"""
+ try:
+ # 查看过去30天内是否有相同IP或相同国家的登录记录
+ thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
+
+ result = await db.exec(
+ select(LoginSession).where(
+ LoginSession.user_id == user_id,
+ LoginSession.created_at > thirty_days_ago,
+ (LoginSession.ip_address == ip_address) |
+ (LoginSession.country_code == country_code)
+ )
+ )
+
+ existing_sessions = result.all()
+
+ # 如果有历史记录,则不是新位置
+ return len(existing_sessions) == 0
+
+ except Exception as e:
+ logger.error(f"[Login Session] Exception during new location check: {e}")
+ # 出错时默认为新位置(更安全)
+ return True
+
+ @staticmethod
+ async def mark_session_verified(
+ db: AsyncSession,
+ user_id: int
+ ) -> bool:
+ """标记用户的未验证会话为已验证"""
+ try:
+ # 查找用户所有未验证且未过期的会话
+ result = await db.exec(
+ select(LoginSession).where(
+ LoginSession.user_id == user_id,
+ LoginSession.is_verified == False,
+ LoginSession.expires_at > datetime.now(UTC)
+ )
+ )
+
+ sessions = result.all()
+
+ # 标记所有会话为已验证
+ for session in sessions:
+ session.is_verified = True
+ session.verified_at = datetime.now(UTC)
+
+ if sessions:
+ logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
+
+ return len(sessions) > 0
+
+ except Exception as e:
+ logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
+ return False
diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py
index c6322f2..51c7df6 100644
--- a/app/service/redis_message_system.py
+++ b/app/service/redis_message_system.py
@@ -65,6 +65,17 @@ class RedisMessageSystem:
if not user.id:
raise ValueError("User ID is required")
+ # 获取频道类型以判断是否需要存储到数据库
+ async with with_db() as session:
+ from app.database.chat import ChatChannel, ChannelType
+ from sqlmodel import select
+
+ channel_result = await session.exec(
+ select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
+ )
+ channel_type = channel_result.first()
+ is_multiplayer = channel_type == ChannelType.MULTIPLAYER
+
# 准备消息数据
message_data = {
"message_id": message_id,
@@ -76,6 +87,7 @@ class RedisMessageSystem:
"uuid": user_uuid or "",
"status": "cached", # Redis 缓存状态
"created_at": time.time(),
+ "is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
}
# 立即存储到 Redis
@@ -118,9 +130,14 @@ class RedisMessageSystem:
uuid=user_uuid,
)
- logger.info(
- f"Message {message_id} sent to Redis cache for channel {channel_id}"
- )
+ if is_multiplayer:
+ logger.info(
+ f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
+ )
+ else:
+ logger.info(
+ f"Message {message_id} sent to Redis cache for channel {channel_id}"
+ )
return response
async def get_messages(
@@ -222,6 +239,9 @@ class RedisMessageSystem:
):
"""存储消息到 Redis"""
try:
+ # 检查是否是多人房间消息
+ is_multiplayer = message_data.get("is_multiplayer", False)
+
# 存储消息数据
await self._redis_exec(
self.redis.hset,
@@ -267,10 +287,14 @@ class RedisMessageSystem:
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
)
- # 添加到待持久化队列
- await self._redis_exec(
- self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
- )
+ # 只有非多人房间消息才添加到待持久化队列
+ if not is_multiplayer:
+ await self._redis_exec(
+ self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
+ )
+ logger.debug(f"Message {message_id} added to persistence queue")
+ else:
+ logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
@@ -475,6 +499,19 @@ class RedisMessageSystem:
v = v.decode("utf-8")
message_data[k] = v
+ # 检查是否是多人房间消息,如果是则跳过数据库存储
+ is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
+ if is_multiplayer:
+ # 多人房间消息不存储到数据库,直接标记为已跳过
+ await self._redis_exec(
+ self.redis.hset,
+ f"msg:{channel_id}:{message_id}",
+ "status",
+ "skipped_multiplayer",
+ )
+ logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
+ continue
+
# 检查消息是否已存在于数据库
existing = await session.get(ChatMessage, int(message_id))
if existing:
diff --git a/app/service/session_manager.py b/app/service/session_manager.py
new file mode 100644
index 0000000..11d78c4
--- /dev/null
+++ b/app/service/session_manager.py
@@ -0,0 +1,121 @@
+"""
+API 状态管理 - 模拟 osu! 的 APIState 和会话管理
+"""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional
+from datetime import datetime
+
+from pydantic import BaseModel
+
+
+class APIState(str, Enum):
+ """API 连接状态,对应 osu! 的 APIState"""
+ OFFLINE = "offline"
+ CONNECTING = "connecting"
+ REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
+ ONLINE = "online"
+ FAILING = "failing"
+
+
+class UserSession(BaseModel):
+ """用户会话信息"""
+ user_id: int
+ username: str
+ email: str
+ session_token: str | None = None
+ state: APIState = APIState.OFFLINE
+ requires_verification: bool = False
+ verification_sent: bool = False
+ last_verification_attempt: datetime | None = None
+ failed_attempts: int = 0
+ ip_address: str | None = None
+ country_code: str | None = None
+ is_new_location: bool = False
+
+
+class SessionManager:
+ """会话管理器"""
+
+ def __init__(self):
+ self._sessions: dict[str, UserSession] = {}
+
+ def create_session(
+ self,
+ user_id: int,
+ username: str,
+ email: str,
+ ip_address: str,
+ country_code: str | None = None,
+ is_new_location: bool = False
+ ) -> UserSession:
+ """创建新的用户会话"""
+ import secrets
+
+ session_token = secrets.token_urlsafe(32)
+
+ # 根据是否为新位置决定初始状态
+ if is_new_location:
+ state = APIState.REQUIRES_SECOND_FACTOR_AUTH
+ else:
+ state = APIState.ONLINE
+
+ session = UserSession(
+ user_id=user_id,
+ username=username,
+ email=email,
+ session_token=session_token,
+ state=state,
+ requires_verification=is_new_location,
+ ip_address=ip_address,
+ country_code=country_code,
+ is_new_location=is_new_location
+ )
+
+ self._sessions[session_token] = session
+ return session
+
+ def get_session(self, session_token: str) -> UserSession | None:
+ """获取会话"""
+ return self._sessions.get(session_token)
+
+ def update_session_state(self, session_token: str, state: APIState):
+ """更新会话状态"""
+ if session_token in self._sessions:
+ self._sessions[session_token].state = state
+
+ def mark_verification_sent(self, session_token: str):
+ """标记验证邮件已发送"""
+ if session_token in self._sessions:
+ session = self._sessions[session_token]
+ session.verification_sent = True
+ session.last_verification_attempt = datetime.now()
+
+ def increment_failed_attempts(self, session_token: str):
+ """增加失败尝试次数"""
+ if session_token in self._sessions:
+ self._sessions[session_token].failed_attempts += 1
+
+ def verify_session(self, session_token: str) -> bool:
+ """验证会话成功"""
+ if session_token in self._sessions:
+ session = self._sessions[session_token]
+ session.state = APIState.ONLINE
+ session.requires_verification = False
+ return True
+ return False
+
+ def remove_session(self, session_token: str):
+ """移除会话"""
+ self._sessions.pop(session_token, None)
+
+ def cleanup_expired_sessions(self):
+ """清理过期会话"""
+ # 这里可以实现清理逻辑
+ pass
+
+
+# 全局会话管理器
+session_manager = SessionManager()
diff --git a/main.py b/main.py
index 4a4a417..716301d 100644
--- a/main.py
+++ b/main.py
@@ -23,6 +23,7 @@ from app.router import (
)
from app.router.redirect import redirect_router
from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
+from app.scheduler.database_cleanup_scheduler import start_database_cleanup_scheduler, stop_database_cleanup_scheduler
from app.service.beatmap_download_service import download_service
from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.create_banchobot import create_banchobot
@@ -79,6 +80,7 @@ async def lifespan(app: FastAPI):
await create_banchobot()
await download_service.start_health_check() # 启动下载服务健康检查
await start_cache_scheduler() # 启动缓存调度器
+ await start_database_cleanup_scheduler() # 启动数据库清理调度器
redis_message_system.start() # 启动 Redis 消息系统
start_stats_scheduler() # 启动统计调度器
load_achievements()
@@ -88,6 +90,7 @@ async def lifespan(app: FastAPI):
redis_message_system.stop() # 停止 Redis 消息系统
stop_stats_scheduler() # 停止统计调度器
await stop_cache_scheduler() # 停止缓存调度器
+ await stop_database_cleanup_scheduler() # 停止数据库清理调度器
await download_service.stop_health_check() # 停止下载服务健康检查
await engine.dispose()
await redis_client.aclose()
diff --git a/migrations/versions/0f96348cdfd2_add_email_verification_tables.py b/migrations/versions/0f96348cdfd2_add_email_verification_tables.py
new file mode 100644
index 0000000..0fd86c6
--- /dev/null
+++ b/migrations/versions/0f96348cdfd2_add_email_verification_tables.py
@@ -0,0 +1,65 @@
+"""add email verification tables
+
+Revision ID: 0f96348cdfd2
+Revises: e96a649e18ca
+Create Date: 2025-08-22 07:26:59.129564
+
+"""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "0f96348cdfd2"
+down_revision: str | Sequence[str] | None = "e96a649e18ca"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ """Upgrade schema."""
+ # 创建邮件验证表
+ op.create_table(
+ "email_verifications",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True, autoincrement=True),
+ sa.Column("user_id", sa.BigInteger(), nullable=False),
+ sa.Column("email", sa.String(255), nullable=False),
+ sa.Column("verification_code", sa.String(8), nullable=False),
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("is_used", sa.Boolean(), nullable=False, default=False),
+ sa.Column("used_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("ip_address", sa.String(255), nullable=True),
+ sa.Column("user_agent", sa.String(255), nullable=True),
+ sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"]),
+ sa.Index("ix_email_verifications_user_id", "user_id"),
+ sa.Index("ix_email_verifications_email", "email"),
+ )
+
+ # 创建登录会话表
+ op.create_table(
+ "login_sessions",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True, autoincrement=True),
+ sa.Column("user_id", sa.BigInteger(), nullable=False),
+ sa.Column("session_token", sa.String(255), nullable=False),
+ sa.Column("ip_address", sa.String(255), nullable=False),
+ sa.Column("user_agent", sa.String(255), nullable=True),
+ sa.Column("country_code", sa.String(255), nullable=True),
+ sa.Column("is_verified", sa.Boolean(), nullable=False, default=False),
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("verified_at", sa.DateTime(timezone=True), nullable=True),
+ sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("is_new_location", sa.Boolean(), nullable=False, default=False),
+ sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"]),
+ sa.Index("ix_login_sessions_user_id", "user_id"),
+ sa.Index("ix_login_sessions_session_token", "session_token", unique=True),
+ )
+
+
+def downgrade() -> None:
+ """Downgrade schema."""
+ op.drop_table("login_sessions")
+ op.drop_table("email_verifications")