添加邮件验证
This commit is contained in:
12
.env.example
12
.env.example
@@ -45,6 +45,18 @@ FETCHER_SCOPES=public
|
|||||||
# 日志设置
|
# 日志设置
|
||||||
LOG_LEVEL="INFO"
|
LOG_LEVEL="INFO"
|
||||||
|
|
||||||
|
# 邮件服务设置
|
||||||
|
SMTP_SERVER="smtp.gmail.com" # SMTP 服务器地址
|
||||||
|
SMTP_PORT=587 # SMTP 端口
|
||||||
|
SMTP_USERNAME="your-email@gmail.com" # 邮箱用户名
|
||||||
|
SMTP_PASSWORD="your-app-password" # 邮箱密码或应用专用密码
|
||||||
|
FROM_EMAIL="noreply@your-server.com" # 发送方邮箱
|
||||||
|
FROM_NAME="osu! Private Server" # 发送方名称
|
||||||
|
|
||||||
|
# 邮件验证功能开关
|
||||||
|
ENABLE_EMAIL_VERIFICATION=true # 是否启用邮件验证功能(新位置登录时需要邮件验证)
|
||||||
|
ENABLE_EMAIL_SENDING=false # 是否真实发送邮件(false时仅模拟发送,输出到日志)
|
||||||
|
|
||||||
# Sentry 设置,为空表示不启用
|
# Sentry 设置,为空表示不启用
|
||||||
SENTRY_DSN
|
SENTRY_DSN
|
||||||
|
|
||||||
|
|||||||
@@ -117,6 +117,22 @@ class Settings(BaseSettings):
|
|||||||
# 日志设置
|
# 日志设置
|
||||||
log_level: str = "INFO"
|
log_level: str = "INFO"
|
||||||
|
|
||||||
|
# 邮件服务设置
|
||||||
|
smtp_server: str = "localhost"
|
||||||
|
smtp_port: int = 587
|
||||||
|
smtp_username: str = ""
|
||||||
|
smtp_password: str = ""
|
||||||
|
from_email: str = "noreply@example.com"
|
||||||
|
from_name: str = "osu! server"
|
||||||
|
|
||||||
|
# 邮件验证功能开关
|
||||||
|
enable_email_verification: bool = Field(
|
||||||
|
default=True, description="是否启用邮件验证功能"
|
||||||
|
)
|
||||||
|
enable_email_sending: bool = Field(
|
||||||
|
default=False, description="是否真实发送邮件(False时仅模拟发送)"
|
||||||
|
)
|
||||||
|
|
||||||
# Sentry 配置
|
# Sentry 配置
|
||||||
sentry_dsn: HttpUrl | None = None
|
sentry_dsn: HttpUrl | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from .counts import (
|
|||||||
ReplayWatchedCount,
|
ReplayWatchedCount,
|
||||||
)
|
)
|
||||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||||
|
from .email_verification import EmailVerification, LoginSession
|
||||||
from .favourite_beatmapset import FavouriteBeatmapset
|
from .favourite_beatmapset import FavouriteBeatmapset
|
||||||
from .lazer_user import (
|
from .lazer_user import (
|
||||||
User,
|
User,
|
||||||
|
|||||||
44
app/database/email_verification.py
Normal file
44
app/database/email_verification.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""
|
||||||
|
邮件验证相关数据库模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
from sqlmodel import SQLModel, Field
|
||||||
|
from sqlalchemy import Column, BigInteger, ForeignKey
|
||||||
|
|
||||||
|
|
||||||
|
class EmailVerification(SQLModel, table=True):
|
||||||
|
"""邮件验证记录"""
|
||||||
|
|
||||||
|
__tablename__: str = "email_verifications"
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||||
|
email: str = Field(index=True)
|
||||||
|
verification_code: str = Field(max_length=8) # 8位验证码
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
expires_at: datetime = Field() # 验证码过期时间
|
||||||
|
is_used: bool = Field(default=False) # 是否已使用
|
||||||
|
used_at: datetime | None = Field(default=None)
|
||||||
|
ip_address: str | None = Field(default=None) # 请求IP
|
||||||
|
user_agent: str | None = Field(default=None) # 用户代理
|
||||||
|
|
||||||
|
|
||||||
|
class LoginSession(SQLModel, table=True):
|
||||||
|
"""登录会话记录"""
|
||||||
|
|
||||||
|
__tablename__: str = "login_sessions"
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||||
|
session_token: str = Field(unique=True, index=True) # 会话令牌
|
||||||
|
ip_address: str = Field() # 登录IP
|
||||||
|
user_agent: str | None = Field(default=None)
|
||||||
|
country_code: str | None = Field(default=None)
|
||||||
|
is_verified: bool = Field(default=False) # 是否已验证
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
verified_at: datetime | None = Field(default=None)
|
||||||
|
expires_at: datetime = Field() # 会话过期时间
|
||||||
|
is_new_location: bool = Field(default=False) # 是否新位置登录
|
||||||
@@ -475,6 +475,25 @@ class UserResp(UserBase):
|
|||||||
)
|
)
|
||||||
).one()
|
).one()
|
||||||
|
|
||||||
|
# 检查会话验证状态
|
||||||
|
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
|
||||||
|
from app.config import settings
|
||||||
|
if not settings.enable_email_verification:
|
||||||
|
u.session_verified = True
|
||||||
|
else:
|
||||||
|
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
|
||||||
|
from .email_verification import LoginSession
|
||||||
|
unverified_session = (
|
||||||
|
await session.exec(
|
||||||
|
select(LoginSession).where(
|
||||||
|
LoginSession.user_id == obj.id,
|
||||||
|
LoginSession.is_verified == False,
|
||||||
|
LoginSession.expires_at > datetime.now(UTC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
u.session_verified = unverified_session is None
|
||||||
|
|
||||||
return u
|
return u
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
17
app/models/api_me.py
Normal file
17
app/models/api_me.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
APIMe 响应模型 - 对应 osu! 的 APIMe 类型
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.database.lazer_user import UserResp
|
||||||
|
|
||||||
|
|
||||||
|
class APIMe(UserResp):
|
||||||
|
"""
|
||||||
|
/me 端点的响应模型
|
||||||
|
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
|
||||||
|
|
||||||
|
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
|
||||||
|
"""
|
||||||
|
pass
|
||||||
31
app/models/extended_auth.py
Normal file
31
app/models/extended_auth.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
扩展的 OAuth 响应模型,支持二次验证
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ExtendedTokenResponse(BaseModel):
|
||||||
|
"""扩展的令牌响应,支持二次验证状态"""
|
||||||
|
access_token: str | None = None
|
||||||
|
token_type: str = "Bearer"
|
||||||
|
expires_in: int | None = None
|
||||||
|
refresh_token: str | None = None
|
||||||
|
scope: str | None = None
|
||||||
|
|
||||||
|
# 二次验证相关字段
|
||||||
|
requires_second_factor: bool = False
|
||||||
|
verification_message: str | None = None
|
||||||
|
user_id: int | None = None # 用于二次验证的用户ID
|
||||||
|
|
||||||
|
|
||||||
|
class SessionState(BaseModel):
|
||||||
|
"""会话状态"""
|
||||||
|
user_id: int
|
||||||
|
username: str
|
||||||
|
email: str
|
||||||
|
requires_verification: bool
|
||||||
|
session_token: str | None = None
|
||||||
|
verification_sent: bool = False
|
||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
import re
|
import re
|
||||||
from typing import Literal
|
from typing import Literal, Union
|
||||||
|
|
||||||
from app.auth import (
|
from app.auth import (
|
||||||
authenticate_user,
|
authenticate_user,
|
||||||
@@ -28,8 +28,13 @@ from app.models.oauth import (
|
|||||||
TokenResponse,
|
TokenResponse,
|
||||||
UserRegistrationErrors,
|
UserRegistrationErrors,
|
||||||
)
|
)
|
||||||
|
from app.models.extended_auth import ExtendedTokenResponse
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
from app.service.login_log_service import LoginLogService
|
from app.service.login_log_service import LoginLogService
|
||||||
|
from app.service.email_verification_service import (
|
||||||
|
EmailVerificationService,
|
||||||
|
LoginSessionService
|
||||||
|
)
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form, Request
|
from fastapi import APIRouter, Depends, Form, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -198,7 +203,7 @@ async def register_user(
|
|||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
response_model=TokenResponse,
|
response_model=Union[TokenResponse, ExtendedTokenResponse],
|
||||||
name="获取访问令牌",
|
name="获取访问令牌",
|
||||||
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
||||||
)
|
)
|
||||||
@@ -218,6 +223,7 @@ async def oauth_token(
|
|||||||
None, description="刷新令牌(仅刷新令牌模式需要)"
|
None, description="刷新令牌(仅刷新令牌模式需要)"
|
||||||
),
|
),
|
||||||
redis: Redis = Depends(get_redis),
|
redis: Redis = Depends(get_redis),
|
||||||
|
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||||
):
|
):
|
||||||
scopes = scope.split(" ")
|
scopes = scope.split(" ")
|
||||||
|
|
||||||
@@ -295,17 +301,68 @@ async def oauth_token(
|
|||||||
# 确保用户对象与当前会话关联
|
# 确保用户对象与当前会话关联
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
# 记录成功的登录
|
# 获取用户信息和客户端信息
|
||||||
user_id = getattr(user, "id")
|
user_id = getattr(user, "id")
|
||||||
assert user_id is not None, "User ID should not be None after authentication"
|
assert user_id is not None, "User ID should not be None after authentication"
|
||||||
await LoginLogService.record_login(
|
|
||||||
db=db,
|
from app.dependencies.geoip import get_client_ip
|
||||||
user_id=user_id,
|
ip_address = get_client_ip(request)
|
||||||
request=request,
|
user_agent = request.headers.get("User-Agent", "")
|
||||||
login_success=True,
|
|
||||||
login_method="password",
|
# 获取国家代码
|
||||||
notes=f"OAuth password grant for client {client_id}",
|
geo_info = geoip.lookup(ip_address)
|
||||||
|
country_code = geo_info.get("country_iso", "XX")
|
||||||
|
|
||||||
|
# 检查是否为新位置登录
|
||||||
|
is_new_location = await LoginSessionService.check_new_location(
|
||||||
|
db, user_id, ip_address, country_code
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建登录会话记录
|
||||||
|
login_session = await LoginSessionService.create_session(
|
||||||
|
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果是新位置登录,需要邮件验证
|
||||||
|
if is_new_location and settings.enable_email_verification:
|
||||||
|
# 刷新用户对象以确保属性已加载
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
# 发送邮件验证码
|
||||||
|
verification_sent = await EmailVerificationService.send_verification_email(
|
||||||
|
db, redis, user_id, user.username, user.email, ip_address, user_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录需要二次验证的登录尝试
|
||||||
|
await LoginLogService.record_login(
|
||||||
|
db=db,
|
||||||
|
user_id=user_id,
|
||||||
|
request=request,
|
||||||
|
login_success=True,
|
||||||
|
login_method="password_pending_verification",
|
||||||
|
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not verification_sent:
|
||||||
|
# 邮件发送失败,记录错误
|
||||||
|
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
|
||||||
|
elif is_new_location and not settings.enable_email_verification:
|
||||||
|
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||||
|
await LoginSessionService.mark_session_verified(db, user_id)
|
||||||
|
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
|
||||||
|
else:
|
||||||
|
# 不是新位置登录,正常登录
|
||||||
|
await LoginLogService.record_login(
|
||||||
|
db=db,
|
||||||
|
user_id=user_id,
|
||||||
|
request=request,
|
||||||
|
login_success=True,
|
||||||
|
login_method="password",
|
||||||
|
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 无论是否新位置登录,都返回正常的token
|
||||||
|
# session_verified状态通过/me接口的session_verified字段来体现
|
||||||
|
|
||||||
# 生成令牌
|
# 生成令牌
|
||||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
|||||||
relationship,
|
relationship,
|
||||||
room,
|
room,
|
||||||
score,
|
score,
|
||||||
|
session_verify,
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
from .router import router as api_v2_router
|
from .router import router as api_v2_router
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from app.database import User, UserResp
|
from app.database import User
|
||||||
from app.database.lazer_user import ALL_INCLUDED
|
from app.database.lazer_user import ALL_INCLUDED
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.dependencies.database import Database
|
from app.dependencies.database import Database
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
from app.models.api_me import APIMe
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ from fastapi import Path, Security
|
|||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/{ruleset}",
|
"/me/{ruleset}",
|
||||||
response_model=UserResp,
|
response_model=APIMe,
|
||||||
name="获取当前用户信息 (指定 ruleset)",
|
name="获取当前用户信息 (指定 ruleset)",
|
||||||
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
|
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
@@ -23,17 +24,18 @@ async def get_user_info_with_ruleset(
|
|||||||
ruleset: GameMode = Path(description="指定 ruleset"),
|
ruleset: GameMode = Path(description="指定 ruleset"),
|
||||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||||
):
|
):
|
||||||
return await UserResp.from_db(
|
user_resp = await APIMe.from_db(
|
||||||
current_user,
|
current_user,
|
||||||
session,
|
session,
|
||||||
ALL_INCLUDED,
|
ALL_INCLUDED,
|
||||||
ruleset,
|
ruleset,
|
||||||
)
|
)
|
||||||
|
return user_resp
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/me/",
|
"/me/",
|
||||||
response_model=UserResp,
|
response_model=APIMe,
|
||||||
name="获取当前用户信息",
|
name="获取当前用户信息",
|
||||||
description="获取当前登录用户信息。",
|
description="获取当前登录用户信息。",
|
||||||
tags=["用户"],
|
tags=["用户"],
|
||||||
@@ -42,9 +44,10 @@ async def get_user_info_default(
|
|||||||
session: Database,
|
session: Database,
|
||||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||||
):
|
):
|
||||||
return await UserResp.from_db(
|
user_resp = await APIMe.from_db(
|
||||||
current_user,
|
current_user,
|
||||||
session,
|
session,
|
||||||
ALL_INCLUDED,
|
ALL_INCLUDED,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
return user_resp
|
||||||
|
|||||||
204
app/router/v2/session_verify.py
Normal file
204
app/router/v2/session_verify.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
会话验证路由 - 实现类似 osu! 的邮件验证流程 (API v2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from app.auth import authenticate_user
|
||||||
|
from app.config import settings
|
||||||
|
from app.database import User
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.dependencies.database import Database, get_redis
|
||||||
|
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
|
||||||
|
from app.database.email_verification import EmailVerification, LoginSession
|
||||||
|
from app.service.email_verification_service import (
|
||||||
|
EmailVerificationService,
|
||||||
|
LoginSessionService
|
||||||
|
)
|
||||||
|
from app.service.login_log_service import LoginLogService
|
||||||
|
from app.models.extended_auth import ExtendedTokenResponse
|
||||||
|
|
||||||
|
from fastapi import Form, Depends, Request, HTTPException, status, Security
|
||||||
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from .router import router
|
||||||
|
|
||||||
|
|
||||||
|
class SessionReissueResponse(BaseModel):
|
||||||
|
"""重新发送验证码响应"""
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/session/verify",
|
||||||
|
name="验证会话",
|
||||||
|
description="验证邮件验证码并完成会话认证",
|
||||||
|
status_code=204
|
||||||
|
)
|
||||||
|
async def verify_session(
|
||||||
|
request: Request,
|
||||||
|
db: Database,
|
||||||
|
redis: Annotated[Redis, Depends(get_redis)],
|
||||||
|
verification_key: str = Form(..., description="8位邮件验证码"),
|
||||||
|
current_user: User = Security(get_current_user)
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
验证邮件验证码并完成会话认证
|
||||||
|
|
||||||
|
对应 osu! 的 session/verify 接口
|
||||||
|
成功时返回 204 No Content,失败时返回 401 Unauthorized
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.dependencies.geoip import get_client_ip
|
||||||
|
ip_address = get_client_ip(request)
|
||||||
|
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||||
|
|
||||||
|
# 从当前认证用户获取信息
|
||||||
|
user_id = current_user.id
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="用户未认证"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证邮件验证码
|
||||||
|
success, message = await EmailVerificationService.verify_code(
|
||||||
|
db, redis, user_id, verification_key
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
# 记录成功的邮件验证
|
||||||
|
await LoginLogService.record_login(
|
||||||
|
db=db,
|
||||||
|
user_id=user_id,
|
||||||
|
request=request,
|
||||||
|
login_method="email_verification",
|
||||||
|
login_success=True,
|
||||||
|
notes=f"邮件验证成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回 204 No Content 表示验证成功
|
||||||
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
else:
|
||||||
|
# 记录失败的邮件验证尝试
|
||||||
|
await LoginLogService.record_failed_login(
|
||||||
|
db=db,
|
||||||
|
request=request,
|
||||||
|
attempted_username=current_user.username,
|
||||||
|
login_method="email_verification",
|
||||||
|
notes=f"邮件验证失败: {message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回 401 Unauthorized 表示验证失败
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=message
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="无效的用户会话"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="验证过程中发生错误"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/session/verify/reissue",
|
||||||
|
name="重新发送验证码",
|
||||||
|
description="重新发送邮件验证码",
|
||||||
|
response_model=SessionReissueResponse
|
||||||
|
)
|
||||||
|
async def reissue_verification_code(
|
||||||
|
request: Request,
|
||||||
|
db: Database,
|
||||||
|
redis: Annotated[Redis, Depends(get_redis)],
|
||||||
|
current_user: User = Security(get_current_user)
|
||||||
|
) -> SessionReissueResponse:
|
||||||
|
"""
|
||||||
|
重新发送邮件验证码
|
||||||
|
|
||||||
|
对应 osu! 的 session/verify/reissue 接口
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.dependencies.geoip import get_client_ip
|
||||||
|
ip_address = get_client_ip(request)
|
||||||
|
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||||
|
|
||||||
|
# 从当前认证用户获取信息
|
||||||
|
user_id = current_user.id
|
||||||
|
if not user_id:
|
||||||
|
return SessionReissueResponse(
|
||||||
|
success=False,
|
||||||
|
message="用户未认证"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 重新发送验证码
|
||||||
|
success, message = await EmailVerificationService.resend_verification_code(
|
||||||
|
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
return SessionReissueResponse(
|
||||||
|
success=success,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return SessionReissueResponse(
|
||||||
|
success=False,
|
||||||
|
message="无效的用户会话"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return SessionReissueResponse(
|
||||||
|
success=False,
|
||||||
|
message="重新发送过程中发生错误"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/session/check-new-location",
|
||||||
|
name="检查新位置登录",
|
||||||
|
description="检查登录是否来自新位置(内部接口)"
|
||||||
|
)
|
||||||
|
async def check_new_location(
|
||||||
|
request: Request,
|
||||||
|
db: Database,
|
||||||
|
user_id: int,
|
||||||
|
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
检查是否为新位置登录
|
||||||
|
这是一个内部接口,用于登录流程中判断是否需要邮件验证
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.dependencies.geoip import get_client_ip
|
||||||
|
ip_address = get_client_ip(request)
|
||||||
|
geo_info = geoip.lookup(ip_address)
|
||||||
|
country_code = geo_info.get("country_iso", "XX")
|
||||||
|
|
||||||
|
is_new_location = await LoginSessionService.check_new_location(
|
||||||
|
db, user_id, ip_address, country_code
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"is_new_location": is_new_location,
|
||||||
|
"ip_address": ip_address,
|
||||||
|
"country_code": country_code
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"is_new_location": True, # 出错时默认为新位置
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
118
app/scheduler/database_cleanup_scheduler.py
Normal file
118
app/scheduler/database_cleanup_scheduler.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
数据库清理调度器 - 定时清理过期数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.dependencies.database import engine
|
||||||
|
from app.log import logger
|
||||||
|
from app.service.database_cleanup_service import DatabaseCleanupService
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseCleanupScheduler:
|
||||||
|
"""数据库清理调度器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.running = False
|
||||||
|
self.task = None
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动调度器"""
|
||||||
|
if self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.task = asyncio.create_task(self._run_scheduler())
|
||||||
|
logger.debug("Database cleanup scheduler started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止调度器"""
|
||||||
|
if not self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = False
|
||||||
|
if self.task:
|
||||||
|
self.task.cancel()
|
||||||
|
try:
|
||||||
|
await self.task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.debug("Database cleanup scheduler stopped")
|
||||||
|
|
||||||
|
async def _run_scheduler(self):
|
||||||
|
"""运行调度器"""
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# 每小时运行一次清理
|
||||||
|
await asyncio.sleep(3600) # 3600秒 = 1小时
|
||||||
|
|
||||||
|
if not self.running:
|
||||||
|
break
|
||||||
|
|
||||||
|
await self._run_cleanup()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Database cleanup scheduler error: {str(e)}")
|
||||||
|
# 发生错误后等待5分钟再继续
|
||||||
|
await asyncio.sleep(300)
|
||||||
|
|
||||||
|
async def _run_cleanup(self):
|
||||||
|
"""执行清理任务"""
|
||||||
|
try:
|
||||||
|
async with AsyncSession(engine) as db:
|
||||||
|
logger.debug("Starting scheduled database cleanup...")
|
||||||
|
|
||||||
|
# 清理过期的验证码
|
||||||
|
expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||||
|
|
||||||
|
# 清理过期的登录会话
|
||||||
|
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||||
|
|
||||||
|
# 只在有清理记录时输出总结
|
||||||
|
total_cleaned = expired_codes + expired_sessions
|
||||||
|
if total_cleaned > 0:
|
||||||
|
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during scheduled database cleanup: {str(e)}")
|
||||||
|
|
||||||
|
async def run_manual_cleanup(self):
|
||||||
|
"""手动运行完整清理"""
|
||||||
|
try:
|
||||||
|
async with AsyncSession(engine) as db:
|
||||||
|
logger.debug("Starting manual database cleanup...")
|
||||||
|
results = await DatabaseCleanupService.run_full_cleanup(db)
|
||||||
|
total = sum(results.values())
|
||||||
|
if total > 0:
|
||||||
|
logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during manual database cleanup: {str(e)}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局实例
|
||||||
|
database_cleanup_scheduler = DatabaseCleanupScheduler()
|
||||||
|
|
||||||
|
|
||||||
|
async def start_database_cleanup_scheduler():
|
||||||
|
"""启动数据库清理调度器"""
|
||||||
|
await database_cleanup_scheduler.start()
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_database_cleanup_scheduler():
|
||||||
|
"""停止数据库清理调度器"""
|
||||||
|
await database_cleanup_scheduler.stop()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_manual_database_cleanup():
|
||||||
|
"""手动运行数据库清理"""
|
||||||
|
return await database_cleanup_scheduler.run_manual_cleanup()
|
||||||
289
app/service/database_cleanup_service.py
Normal file
289
app/service/database_cleanup_service.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
"""
|
||||||
|
数据库清理服务 - 清理过期的验证码和会话
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, UTC, timedelta
|
||||||
|
|
||||||
|
from app.database.email_verification import EmailVerification, LoginSession
|
||||||
|
from app.log import logger
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseCleanupService:
|
||||||
|
"""数据库清理服务"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
|
||||||
|
"""
|
||||||
|
清理过期的邮件验证码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 清理的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找过期的验证码记录
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
stmt = select(EmailVerification).where(
|
||||||
|
EmailVerification.expires_at < current_time
|
||||||
|
)
|
||||||
|
result = await db.exec(stmt)
|
||||||
|
expired_codes = result.all()
|
||||||
|
|
||||||
|
# 删除过期的记录
|
||||||
|
deleted_count = 0
|
||||||
|
for code in expired_codes:
|
||||||
|
await db.delete(code)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
|
||||||
|
"""
|
||||||
|
清理过期的登录会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 清理的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找过期的登录会话记录
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
stmt = select(LoginSession).where(
|
||||||
|
LoginSession.expires_at < current_time
|
||||||
|
)
|
||||||
|
result = await db.exec(stmt)
|
||||||
|
expired_sessions = result.all()
|
||||||
|
|
||||||
|
# 删除过期的记录
|
||||||
|
deleted_count = 0
|
||||||
|
for session in expired_sessions:
|
||||||
|
await db.delete(session)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
|
||||||
|
"""
|
||||||
|
清理旧的已使用验证码记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
days_old: 清理多少天前的已使用记录,默认7天
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 清理的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找指定天数前的已使用验证码记录
|
||||||
|
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||||
|
|
||||||
|
stmt = select(EmailVerification).where(
|
||||||
|
EmailVerification.is_used == True
|
||||||
|
)
|
||||||
|
result = await db.exec(stmt)
|
||||||
|
all_used_codes = result.all()
|
||||||
|
|
||||||
|
# 筛选出过期的记录
|
||||||
|
old_used_codes = [
|
||||||
|
code for code in all_used_codes
|
||||||
|
if code.used_at and code.used_at < cutoff_time
|
||||||
|
]
|
||||||
|
|
||||||
|
# 删除旧的已使用记录
|
||||||
|
deleted_count = 0
|
||||||
|
for code in old_used_codes:
|
||||||
|
await db.delete(code)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||||
|
"""
|
||||||
|
清理旧的已验证会话记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
days_old: 清理多少天前的已验证记录,默认30天
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 清理的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 查找指定天数前的已验证会话记录
|
||||||
|
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||||
|
|
||||||
|
stmt = select(LoginSession).where(
|
||||||
|
LoginSession.is_verified == True
|
||||||
|
)
|
||||||
|
result = await db.exec(stmt)
|
||||||
|
all_verified_sessions = result.all()
|
||||||
|
|
||||||
|
# 筛选出过期的记录
|
||||||
|
old_verified_sessions = [
|
||||||
|
session for session in all_verified_sessions
|
||||||
|
if session.verified_at and session.verified_at < cutoff_time
|
||||||
|
]
|
||||||
|
|
||||||
|
# 删除旧的已验证记录
|
||||||
|
deleted_count = 0
|
||||||
|
for session in old_verified_sessions:
|
||||||
|
await db.delete(session)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
运行完整的清理流程
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 各项清理的结果统计
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# 清理过期的验证码
|
||||||
|
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||||
|
|
||||||
|
# 清理过期的登录会话
|
||||||
|
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||||
|
|
||||||
|
# 清理7天前的已使用验证码
|
||||||
|
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||||
|
|
||||||
|
# 清理30天前的已验证会话
|
||||||
|
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
|
||||||
|
|
||||||
|
total_cleaned = sum(results.values())
|
||||||
|
if total_cleaned > 0:
|
||||||
|
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
获取清理统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 统计信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
cutoff_7_days = current_time - timedelta(days=7)
|
||||||
|
cutoff_30_days = current_time - timedelta(days=30)
|
||||||
|
|
||||||
|
# 统计过期的验证码数量
|
||||||
|
expired_codes_stmt = select(EmailVerification).where(
|
||||||
|
EmailVerification.expires_at < current_time
|
||||||
|
)
|
||||||
|
expired_codes_result = await db.exec(expired_codes_stmt)
|
||||||
|
expired_codes_count = len(expired_codes_result.all())
|
||||||
|
|
||||||
|
# 统计过期的登录会话数量
|
||||||
|
expired_sessions_stmt = select(LoginSession).where(
|
||||||
|
LoginSession.expires_at < current_time
|
||||||
|
)
|
||||||
|
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||||
|
expired_sessions_count = len(expired_sessions_result.all())
|
||||||
|
|
||||||
|
# 统计7天前的已使用验证码数量
|
||||||
|
old_used_codes_stmt = select(EmailVerification).where(
|
||||||
|
EmailVerification.is_used == True
|
||||||
|
)
|
||||||
|
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
||||||
|
all_used_codes = old_used_codes_result.all()
|
||||||
|
old_used_codes_count = len([
|
||||||
|
code for code in all_used_codes
|
||||||
|
if code.used_at and code.used_at < cutoff_7_days
|
||||||
|
])
|
||||||
|
|
||||||
|
# 统计30天前的已验证会话数量
|
||||||
|
old_verified_sessions_stmt = select(LoginSession).where(
|
||||||
|
LoginSession.is_verified == True
|
||||||
|
)
|
||||||
|
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||||
|
all_verified_sessions = old_verified_sessions_result.all()
|
||||||
|
old_verified_sessions_count = len([
|
||||||
|
session for session in all_verified_sessions
|
||||||
|
if session.verified_at and session.verified_at < cutoff_30_days
|
||||||
|
])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"expired_verification_codes": expired_codes_count,
|
||||||
|
"expired_login_sessions": expired_sessions_count,
|
||||||
|
"old_used_verification_codes": old_used_codes_count,
|
||||||
|
"old_verified_sessions": old_verified_sessions_count,
|
||||||
|
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
|
||||||
|
return {
|
||||||
|
"expired_verification_codes": 0,
|
||||||
|
"expired_login_sessions": 0,
|
||||||
|
"old_used_verification_codes": 0,
|
||||||
|
"old_verified_sessions": 0,
|
||||||
|
"total_cleanable": 0
|
||||||
|
}
|
||||||
167
app/service/email_service.py
Normal file
167
app/service/email_service.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
邮件验证服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import smtplib
|
||||||
|
from email.mime.text import MIMEText
|
||||||
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from datetime import datetime, UTC, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class EmailService:
|
||||||
|
"""邮件发送服务"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
|
||||||
|
self.smtp_port = getattr(settings, 'smtp_port', 587)
|
||||||
|
self.smtp_username = getattr(settings, 'smtp_username', '')
|
||||||
|
self.smtp_password = getattr(settings, 'smtp_password', '')
|
||||||
|
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
|
||||||
|
self.from_name = getattr(settings, 'from_name', 'osu! server')
|
||||||
|
|
||||||
|
def generate_verification_code(self) -> str:
|
||||||
|
"""生成8位验证码"""
|
||||||
|
# 只使用数字,避免混淆
|
||||||
|
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||||
|
|
||||||
|
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
|
||||||
|
"""发送验证邮件"""
|
||||||
|
try:
|
||||||
|
msg = MIMEMultipart()
|
||||||
|
msg['From'] = f"{self.from_name} <{self.from_email}>"
|
||||||
|
msg['To'] = email
|
||||||
|
msg['Subject'] = "邮箱验证 - Email Verification"
|
||||||
|
|
||||||
|
# HTML 邮件内容
|
||||||
|
html_content = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<style>
|
||||||
|
.container {{
|
||||||
|
max-width: 600px;
|
||||||
|
margin: 0 auto;
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
}}
|
||||||
|
.header {{
|
||||||
|
background: linear-gradient(135deg, #ff66aa, #ff9966);
|
||||||
|
color: white;
|
||||||
|
padding: 20px;
|
||||||
|
text-align: center;
|
||||||
|
border-radius: 10px 10px 0 0;
|
||||||
|
}}
|
||||||
|
.content {{
|
||||||
|
background: #f9f9f9;
|
||||||
|
padding: 30px;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
}}
|
||||||
|
.code {{
|
||||||
|
background: #fff;
|
||||||
|
border: 2px solid #ff66aa;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 15px;
|
||||||
|
text-align: center;
|
||||||
|
font-size: 24px;
|
||||||
|
font-weight: bold;
|
||||||
|
letter-spacing: 3px;
|
||||||
|
margin: 20px 0;
|
||||||
|
color: #333;
|
||||||
|
}}
|
||||||
|
.footer {{
|
||||||
|
background: #333;
|
||||||
|
color: #fff;
|
||||||
|
padding: 15px;
|
||||||
|
text-align: center;
|
||||||
|
border-radius: 0 0 10px 10px;
|
||||||
|
font-size: 12px;
|
||||||
|
}}
|
||||||
|
.warning {{
|
||||||
|
background: #fff3cd;
|
||||||
|
border: 1px solid #ffeaa7;
|
||||||
|
border-radius: 5px;
|
||||||
|
padding: 10px;
|
||||||
|
margin: 15px 0;
|
||||||
|
color: #856404;
|
||||||
|
}}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="header">
|
||||||
|
<h1> osu! 邮箱验证</h1>
|
||||||
|
<p>Email Verification</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="content">
|
||||||
|
<h2>你好 {username}!</h2>
|
||||||
|
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
|
||||||
|
|
||||||
|
<div class="code">{code}</div>
|
||||||
|
|
||||||
|
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
|
||||||
|
|
||||||
|
<div class="warning">
|
||||||
|
<strong>注意:</strong>
|
||||||
|
<ul>
|
||||||
|
<li>请不要与任何人分享这个验证码</li>
|
||||||
|
<li>如果你没有请求此验证码,请忽略这封邮件</li>
|
||||||
|
<li>验证码只能使用一次</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>如果你有任何问题,请联系我们的支持团队。</p>
|
||||||
|
|
||||||
|
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||||
|
|
||||||
|
<h3>Hello {username}!</h3>
|
||||||
|
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
|
||||||
|
|
||||||
|
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
|
||||||
|
|
||||||
|
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="footer">
|
||||||
|
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||||
|
<p>This email was sent automatically, please do not reply.</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
|
||||||
|
|
||||||
|
# 发送邮件
|
||||||
|
if not settings.enable_email_sending:
|
||||||
|
# 邮件发送功能禁用时只记录日志,不实际发送
|
||||||
|
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
|
||||||
|
if self.smtp_username and self.smtp_password:
|
||||||
|
server.starttls()
|
||||||
|
server.login(self.smtp_username, self.smtp_password)
|
||||||
|
|
||||||
|
server.send_message(msg)
|
||||||
|
|
||||||
|
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Email Verification] Failed to send email: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# 全局邮件服务实例
|
||||||
|
email_service = EmailService()
|
||||||
367
app/service/email_verification_service.py
Normal file
367
app/service/email_verification_service.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
"""
|
||||||
|
邮件验证管理服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from datetime import datetime, UTC, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.database.email_verification import EmailVerification, LoginSession
|
||||||
|
from app.service.email_service import email_service
|
||||||
|
from app.log import logger
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlmodel import select
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
|
||||||
|
class EmailVerificationService:
|
||||||
|
"""邮件验证服务"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_verification_code() -> str:
|
||||||
|
"""生成8位验证码"""
|
||||||
|
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_session_token() -> str:
|
||||||
|
"""生成会话令牌"""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_verification_record(
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
user_id: int,
|
||||||
|
email: str,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
user_agent: str | None = None
|
||||||
|
) -> tuple[EmailVerification, str]:
|
||||||
|
"""创建邮件验证记录"""
|
||||||
|
|
||||||
|
# 检查是否有未过期的验证码
|
||||||
|
existing_result = await db.exec(
|
||||||
|
select(EmailVerification).where(
|
||||||
|
EmailVerification.user_id == user_id,
|
||||||
|
EmailVerification.is_used == False,
|
||||||
|
EmailVerification.expires_at > datetime.now(UTC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = existing_result.first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# 如果有未过期的验证码,直接返回
|
||||||
|
return existing, existing.verification_code
|
||||||
|
|
||||||
|
# 生成新的验证码
|
||||||
|
code = EmailVerificationService.generate_verification_code()
|
||||||
|
|
||||||
|
# 创建验证记录
|
||||||
|
verification = EmailVerification(
|
||||||
|
user_id=user_id,
|
||||||
|
email=email,
|
||||||
|
verification_code=code,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(verification)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(verification)
|
||||||
|
|
||||||
|
# 存储到 Redis(用于快速验证)
|
||||||
|
await redis.setex(
|
||||||
|
f"email_verification:{user_id}:{code}",
|
||||||
|
600, # 10分钟过期
|
||||||
|
str(verification.id) if verification.id else "0"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
|
||||||
|
return verification, code
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_verification_email(
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
user_id: int,
|
||||||
|
username: str,
|
||||||
|
email: str,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
user_agent: str | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""发送验证邮件"""
|
||||||
|
try:
|
||||||
|
# 检查是否启用邮件验证功能
|
||||||
|
if not settings.enable_email_verification:
|
||||||
|
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
|
||||||
|
return True # 返回成功,但不执行验证流程
|
||||||
|
|
||||||
|
# 创建验证记录
|
||||||
|
verification, code = await EmailVerificationService.create_verification_record(
|
||||||
|
db, redis, user_id, email, ip_address, user_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送邮件
|
||||||
|
success = await email_service.send_verification_email(email, code, username)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"[Email Verification] Successfully sent verification email to {email} (user: {username})")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"[Email Verification] Failed to send verification email: {email} (user: {username})")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def verify_code(
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
user_id: int,
|
||||||
|
code: str,
|
||||||
|
ip_address: str | None = None
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""验证验证码"""
|
||||||
|
try:
|
||||||
|
# 检查是否启用邮件验证功能
|
||||||
|
if not settings.enable_email_verification:
|
||||||
|
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
|
||||||
|
# 仍然标记登录会话为已验证
|
||||||
|
await LoginSessionService.mark_session_verified(db, user_id)
|
||||||
|
return True, "验证成功(邮件验证功能已禁用)"
|
||||||
|
|
||||||
|
# 先从 Redis 检查
|
||||||
|
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
|
||||||
|
if not verification_id:
|
||||||
|
return False, "验证码无效或已过期"
|
||||||
|
|
||||||
|
# 从数据库获取验证记录
|
||||||
|
result = await db.exec(
|
||||||
|
select(EmailVerification).where(
|
||||||
|
EmailVerification.id == int(verification_id),
|
||||||
|
EmailVerification.user_id == user_id,
|
||||||
|
EmailVerification.verification_code == code,
|
||||||
|
EmailVerification.is_used == False,
|
||||||
|
EmailVerification.expires_at > datetime.now(UTC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
verification = result.first()
|
||||||
|
if not verification:
|
||||||
|
return False, "验证码无效或已过期"
|
||||||
|
|
||||||
|
# 标记为已使用
|
||||||
|
verification.is_used = True
|
||||||
|
verification.used_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
# 同时更新对应的登录会话状态
|
||||||
|
await LoginSessionService.mark_session_verified(db, user_id)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 删除 Redis 记录
|
||||||
|
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||||
|
|
||||||
|
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||||
|
return True, "验证成功"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
|
||||||
|
return False, "验证过程中发生错误"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def resend_verification_code(
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
user_id: int,
|
||||||
|
username: str,
|
||||||
|
email: str,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
user_agent: str | None = None
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""重新发送验证码"""
|
||||||
|
try:
|
||||||
|
# 检查是否启用邮件验证功能
|
||||||
|
if not settings.enable_email_verification:
|
||||||
|
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
|
||||||
|
return True, "验证码已发送(邮件验证功能已禁用)"
|
||||||
|
|
||||||
|
# 检查重发频率限制(60秒内只能发送一次)
|
||||||
|
rate_limit_key = f"email_verification_rate_limit:{user_id}"
|
||||||
|
if await redis.get(rate_limit_key):
|
||||||
|
return False, "请等待60秒后再重新发送"
|
||||||
|
|
||||||
|
# 设置频率限制
|
||||||
|
await redis.setex(rate_limit_key, 60, "1")
|
||||||
|
|
||||||
|
# 生成新的验证码
|
||||||
|
success = await EmailVerificationService.send_verification_email(
|
||||||
|
db, redis, user_id, username, email, ip_address, user_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return True, "验证码已重新发送"
|
||||||
|
else:
|
||||||
|
return False, "重新发送失败,请稍后再试"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
|
||||||
|
return False, "重新发送过程中发生错误"
|
||||||
|
|
||||||
|
|
||||||
|
class LoginSessionService:
|
||||||
|
"""登录会话服务"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_session(
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
user_id: int,
|
||||||
|
ip_address: str,
|
||||||
|
user_agent: str | None = None,
|
||||||
|
country_code: str | None = None,
|
||||||
|
is_new_location: bool = False
|
||||||
|
) -> LoginSession:
|
||||||
|
"""创建登录会话"""
|
||||||
|
session_token = EmailVerificationService.generate_session_token()
|
||||||
|
|
||||||
|
session = LoginSession(
|
||||||
|
user_id=user_id,
|
||||||
|
session_token=session_token,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
country_code=country_code,
|
||||||
|
is_new_location=is_new_location,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
|
||||||
|
is_verified=not is_new_location # 新位置需要验证
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(session)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(session)
|
||||||
|
|
||||||
|
# 存储到 Redis
|
||||||
|
await redis.setex(
|
||||||
|
f"login_session:{session_token}",
|
||||||
|
86400, # 24小时
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||||
|
return session
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def verify_session(
|
||||||
|
db: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
session_token: str,
|
||||||
|
verification_code: str
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""验证会话(通过邮件验证码)"""
|
||||||
|
try:
|
||||||
|
# 从 Redis 获取用户ID
|
||||||
|
user_id = await redis.get(f"login_session:{session_token}")
|
||||||
|
if not user_id:
|
||||||
|
return False, "会话无效或已过期"
|
||||||
|
|
||||||
|
user_id = int(user_id)
|
||||||
|
|
||||||
|
# 验证邮件验证码
|
||||||
|
success, message = await EmailVerificationService.verify_code(
|
||||||
|
db, redis, user_id, verification_code
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return False, message
|
||||||
|
|
||||||
|
# 更新会话状态
|
||||||
|
result = await db.exec(
|
||||||
|
select(LoginSession).where(
|
||||||
|
LoginSession.session_token == session_token,
|
||||||
|
LoginSession.user_id == user_id,
|
||||||
|
LoginSession.is_verified == False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session = result.first()
|
||||||
|
if session:
|
||||||
|
session.is_verified = True
|
||||||
|
session.verified_at = datetime.now(UTC)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||||
|
return True, "会话验证成功"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||||
|
return False, "验证过程中发生错误"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_new_location(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: int,
|
||||||
|
ip_address: str,
|
||||||
|
country_code: str | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""检查是否为新位置登录"""
|
||||||
|
try:
|
||||||
|
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||||
|
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
|
||||||
|
|
||||||
|
result = await db.exec(
|
||||||
|
select(LoginSession).where(
|
||||||
|
LoginSession.user_id == user_id,
|
||||||
|
LoginSession.created_at > thirty_days_ago,
|
||||||
|
(LoginSession.ip_address == ip_address) |
|
||||||
|
(LoginSession.country_code == country_code)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_sessions = result.all()
|
||||||
|
|
||||||
|
# 如果有历史记录,则不是新位置
|
||||||
|
return len(existing_sessions) == 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||||
|
# 出错时默认为新位置(更安全)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def mark_session_verified(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: int
|
||||||
|
) -> bool:
|
||||||
|
"""标记用户的未验证会话为已验证"""
|
||||||
|
try:
|
||||||
|
# 查找用户所有未验证且未过期的会话
|
||||||
|
result = await db.exec(
|
||||||
|
select(LoginSession).where(
|
||||||
|
LoginSession.user_id == user_id,
|
||||||
|
LoginSession.is_verified == False,
|
||||||
|
LoginSession.expires_at > datetime.now(UTC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
sessions = result.all()
|
||||||
|
|
||||||
|
# 标记所有会话为已验证
|
||||||
|
for session in sessions:
|
||||||
|
session.is_verified = True
|
||||||
|
session.verified_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
if sessions:
|
||||||
|
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||||
|
|
||||||
|
return len(sessions) > 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||||
|
return False
|
||||||
@@ -65,6 +65,17 @@ class RedisMessageSystem:
|
|||||||
if not user.id:
|
if not user.id:
|
||||||
raise ValueError("User ID is required")
|
raise ValueError("User ID is required")
|
||||||
|
|
||||||
|
# 获取频道类型以判断是否需要存储到数据库
|
||||||
|
async with with_db() as session:
|
||||||
|
from app.database.chat import ChatChannel, ChannelType
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
channel_result = await session.exec(
|
||||||
|
select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
|
||||||
|
)
|
||||||
|
channel_type = channel_result.first()
|
||||||
|
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
|
||||||
|
|
||||||
# 准备消息数据
|
# 准备消息数据
|
||||||
message_data = {
|
message_data = {
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
@@ -76,6 +87,7 @@ class RedisMessageSystem:
|
|||||||
"uuid": user_uuid or "",
|
"uuid": user_uuid or "",
|
||||||
"status": "cached", # Redis 缓存状态
|
"status": "cached", # Redis 缓存状态
|
||||||
"created_at": time.time(),
|
"created_at": time.time(),
|
||||||
|
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
|
||||||
}
|
}
|
||||||
|
|
||||||
# 立即存储到 Redis
|
# 立即存储到 Redis
|
||||||
@@ -118,9 +130,14 @@ class RedisMessageSystem:
|
|||||||
uuid=user_uuid,
|
uuid=user_uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
if is_multiplayer:
|
||||||
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
logger.info(
|
||||||
)
|
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def get_messages(
|
async def get_messages(
|
||||||
@@ -222,6 +239,9 @@ class RedisMessageSystem:
|
|||||||
):
|
):
|
||||||
"""存储消息到 Redis"""
|
"""存储消息到 Redis"""
|
||||||
try:
|
try:
|
||||||
|
# 检查是否是多人房间消息
|
||||||
|
is_multiplayer = message_data.get("is_multiplayer", False)
|
||||||
|
|
||||||
# 存储消息数据
|
# 存储消息数据
|
||||||
await self._redis_exec(
|
await self._redis_exec(
|
||||||
self.redis.hset,
|
self.redis.hset,
|
||||||
@@ -267,10 +287,14 @@ class RedisMessageSystem:
|
|||||||
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
|
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加到待持久化队列
|
# 只有非多人房间消息才添加到待持久化队列
|
||||||
await self._redis_exec(
|
if not is_multiplayer:
|
||||||
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
await self._redis_exec(
|
||||||
)
|
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
||||||
|
)
|
||||||
|
logger.debug(f"Message {message_id} added to persistence queue")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to store message to Redis: {e}")
|
logger.error(f"Failed to store message to Redis: {e}")
|
||||||
@@ -475,6 +499,19 @@ class RedisMessageSystem:
|
|||||||
v = v.decode("utf-8")
|
v = v.decode("utf-8")
|
||||||
message_data[k] = v
|
message_data[k] = v
|
||||||
|
|
||||||
|
# 检查是否是多人房间消息,如果是则跳过数据库存储
|
||||||
|
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
|
||||||
|
if is_multiplayer:
|
||||||
|
# 多人房间消息不存储到数据库,直接标记为已跳过
|
||||||
|
await self._redis_exec(
|
||||||
|
self.redis.hset,
|
||||||
|
f"msg:{channel_id}:{message_id}",
|
||||||
|
"status",
|
||||||
|
"skipped_multiplayer",
|
||||||
|
)
|
||||||
|
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
|
||||||
|
continue
|
||||||
|
|
||||||
# 检查消息是否已存在于数据库
|
# 检查消息是否已存在于数据库
|
||||||
existing = await session.get(ChatMessage, int(message_id))
|
existing = await session.get(ChatMessage, int(message_id))
|
||||||
if existing:
|
if existing:
|
||||||
|
|||||||
121
app/service/session_manager.py
Normal file
121
app/service/session_manager.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
API 状态管理 - 模拟 osu! 的 APIState 和会话管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class APIState(str, Enum):
|
||||||
|
"""API 连接状态,对应 osu! 的 APIState"""
|
||||||
|
OFFLINE = "offline"
|
||||||
|
CONNECTING = "connecting"
|
||||||
|
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
|
||||||
|
ONLINE = "online"
|
||||||
|
FAILING = "failing"
|
||||||
|
|
||||||
|
|
||||||
|
class UserSession(BaseModel):
|
||||||
|
"""用户会话信息"""
|
||||||
|
user_id: int
|
||||||
|
username: str
|
||||||
|
email: str
|
||||||
|
session_token: str | None = None
|
||||||
|
state: APIState = APIState.OFFLINE
|
||||||
|
requires_verification: bool = False
|
||||||
|
verification_sent: bool = False
|
||||||
|
last_verification_attempt: datetime | None = None
|
||||||
|
failed_attempts: int = 0
|
||||||
|
ip_address: str | None = None
|
||||||
|
country_code: str | None = None
|
||||||
|
is_new_location: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""会话管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._sessions: dict[str, UserSession] = {}
|
||||||
|
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
username: str,
|
||||||
|
email: str,
|
||||||
|
ip_address: str,
|
||||||
|
country_code: str | None = None,
|
||||||
|
is_new_location: bool = False
|
||||||
|
) -> UserSession:
|
||||||
|
"""创建新的用户会话"""
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
session_token = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
# 根据是否为新位置决定初始状态
|
||||||
|
if is_new_location:
|
||||||
|
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
|
||||||
|
else:
|
||||||
|
state = APIState.ONLINE
|
||||||
|
|
||||||
|
session = UserSession(
|
||||||
|
user_id=user_id,
|
||||||
|
username=username,
|
||||||
|
email=email,
|
||||||
|
session_token=session_token,
|
||||||
|
state=state,
|
||||||
|
requires_verification=is_new_location,
|
||||||
|
ip_address=ip_address,
|
||||||
|
country_code=country_code,
|
||||||
|
is_new_location=is_new_location
|
||||||
|
)
|
||||||
|
|
||||||
|
self._sessions[session_token] = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
def get_session(self, session_token: str) -> UserSession | None:
|
||||||
|
"""获取会话"""
|
||||||
|
return self._sessions.get(session_token)
|
||||||
|
|
||||||
|
def update_session_state(self, session_token: str, state: APIState):
|
||||||
|
"""更新会话状态"""
|
||||||
|
if session_token in self._sessions:
|
||||||
|
self._sessions[session_token].state = state
|
||||||
|
|
||||||
|
def mark_verification_sent(self, session_token: str):
|
||||||
|
"""标记验证邮件已发送"""
|
||||||
|
if session_token in self._sessions:
|
||||||
|
session = self._sessions[session_token]
|
||||||
|
session.verification_sent = True
|
||||||
|
session.last_verification_attempt = datetime.now()
|
||||||
|
|
||||||
|
def increment_failed_attempts(self, session_token: str):
|
||||||
|
"""增加失败尝试次数"""
|
||||||
|
if session_token in self._sessions:
|
||||||
|
self._sessions[session_token].failed_attempts += 1
|
||||||
|
|
||||||
|
def verify_session(self, session_token: str) -> bool:
|
||||||
|
"""验证会话成功"""
|
||||||
|
if session_token in self._sessions:
|
||||||
|
session = self._sessions[session_token]
|
||||||
|
session.state = APIState.ONLINE
|
||||||
|
session.requires_verification = False
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def remove_session(self, session_token: str):
|
||||||
|
"""移除会话"""
|
||||||
|
self._sessions.pop(session_token, None)
|
||||||
|
|
||||||
|
def cleanup_expired_sessions(self):
|
||||||
|
"""清理过期会话"""
|
||||||
|
# 这里可以实现清理逻辑
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# 全局会话管理器
|
||||||
|
session_manager = SessionManager()
|
||||||
3
main.py
3
main.py
@@ -23,6 +23,7 @@ from app.router import (
|
|||||||
)
|
)
|
||||||
from app.router.redirect import redirect_router
|
from app.router.redirect import redirect_router
|
||||||
from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
|
from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
|
||||||
|
from app.scheduler.database_cleanup_scheduler import start_database_cleanup_scheduler, stop_database_cleanup_scheduler
|
||||||
from app.service.beatmap_download_service import download_service
|
from app.service.beatmap_download_service import download_service
|
||||||
from app.service.calculate_all_user_rank import calculate_user_rank
|
from app.service.calculate_all_user_rank import calculate_user_rank
|
||||||
from app.service.create_banchobot import create_banchobot
|
from app.service.create_banchobot import create_banchobot
|
||||||
@@ -79,6 +80,7 @@ async def lifespan(app: FastAPI):
|
|||||||
await create_banchobot()
|
await create_banchobot()
|
||||||
await download_service.start_health_check() # 启动下载服务健康检查
|
await download_service.start_health_check() # 启动下载服务健康检查
|
||||||
await start_cache_scheduler() # 启动缓存调度器
|
await start_cache_scheduler() # 启动缓存调度器
|
||||||
|
await start_database_cleanup_scheduler() # 启动数据库清理调度器
|
||||||
redis_message_system.start() # 启动 Redis 消息系统
|
redis_message_system.start() # 启动 Redis 消息系统
|
||||||
start_stats_scheduler() # 启动统计调度器
|
start_stats_scheduler() # 启动统计调度器
|
||||||
load_achievements()
|
load_achievements()
|
||||||
@@ -88,6 +90,7 @@ async def lifespan(app: FastAPI):
|
|||||||
redis_message_system.stop() # 停止 Redis 消息系统
|
redis_message_system.stop() # 停止 Redis 消息系统
|
||||||
stop_stats_scheduler() # 停止统计调度器
|
stop_stats_scheduler() # 停止统计调度器
|
||||||
await stop_cache_scheduler() # 停止缓存调度器
|
await stop_cache_scheduler() # 停止缓存调度器
|
||||||
|
await stop_database_cleanup_scheduler() # 停止数据库清理调度器
|
||||||
await download_service.stop_health_check() # 停止下载服务健康检查
|
await download_service.stop_health_check() # 停止下载服务健康检查
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
await redis_client.aclose()
|
await redis_client.aclose()
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
"""add email verification tables
|
||||||
|
|
||||||
|
Revision ID: 0f96348cdfd2
|
||||||
|
Revises: e96a649e18ca
|
||||||
|
Create Date: 2025-08-22 07:26:59.129564
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0f96348cdfd2"
|
||||||
|
down_revision: str | Sequence[str] | None = "e96a649e18ca"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
# 创建邮件验证表
|
||||||
|
op.create_table(
|
||||||
|
"email_verifications",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False, primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("email", sa.String(255), nullable=False),
|
||||||
|
sa.Column("verification_code", sa.String(8), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_used", sa.Boolean(), nullable=False, default=False),
|
||||||
|
sa.Column("used_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(255), nullable=True),
|
||||||
|
sa.Column("user_agent", sa.String(255), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"]),
|
||||||
|
sa.Index("ix_email_verifications_user_id", "user_id"),
|
||||||
|
sa.Index("ix_email_verifications_email", "email"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建登录会话表
|
||||||
|
op.create_table(
|
||||||
|
"login_sessions",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False, primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("session_token", sa.String(255), nullable=False),
|
||||||
|
sa.Column("ip_address", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_agent", sa.String(255), nullable=True),
|
||||||
|
sa.Column("country_code", sa.String(255), nullable=True),
|
||||||
|
sa.Column("is_verified", sa.Boolean(), nullable=False, default=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("verified_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_new_location", sa.Boolean(), nullable=False, default=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"]),
|
||||||
|
sa.Index("ix_login_sessions_user_id", "user_id"),
|
||||||
|
sa.Index("ix_login_sessions_session_token", "session_token", unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
op.drop_table("login_sessions")
|
||||||
|
op.drop_table("email_verifications")
|
||||||
Reference in New Issue
Block a user