57
app/auth.py
57
app/auth.py
@@ -317,13 +317,51 @@ def totp_redis_key(user: User) -> str:
|
||||
return f"totp:setup:{user.email}"
|
||||
|
||||
|
||||
def _generate_totp_account_label(user: User) -> str:
|
||||
"""生成TOTP账户标签
|
||||
|
||||
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
|
||||
"""
|
||||
if settings.totp_use_username_in_label:
|
||||
# 使用用户名作为主要标识
|
||||
primary_identifier = user.username
|
||||
else:
|
||||
# 使用邮箱作为标识
|
||||
primary_identifier = user.email
|
||||
|
||||
# 如果配置了服务名称,添加到标签中以便在认证器中区分
|
||||
if settings.totp_service_name:
|
||||
return f"{primary_identifier} ({settings.totp_service_name})"
|
||||
else:
|
||||
return primary_identifier
|
||||
|
||||
|
||||
def _generate_totp_issuer_name() -> str:
|
||||
"""生成TOTP发行者名称
|
||||
|
||||
优先使用自定义的totp_issuer,否则使用服务名称
|
||||
"""
|
||||
if settings.totp_issuer:
|
||||
return settings.totp_issuer
|
||||
elif settings.totp_service_name:
|
||||
return settings.totp_service_name
|
||||
else:
|
||||
# 回退到默认值
|
||||
return "osu! Private Server"
|
||||
|
||||
|
||||
async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp:
|
||||
secret = pyotp.random_base32()
|
||||
await redis.hset(totp_redis_key(user), mapping={"secret": secret, "fails": 0}) # pyright: ignore[reportGeneralTypeIssues]
|
||||
await redis.expire(totp_redis_key(user), 300)
|
||||
|
||||
# 生成更完整的账户标签和issuer信息
|
||||
account_label = _generate_totp_account_label(user)
|
||||
issuer_name = _generate_totp_issuer_name()
|
||||
|
||||
return StartCreateTotpKeyResp(
|
||||
secret=secret,
|
||||
uri=pyotp.totp.TOTP(secret).provisioning_uri(name=user.email, issuer_name=settings.totp_issuer),
|
||||
uri=pyotp.totp.TOTP(secret).provisioning_uri(name=account_label, issuer_name=issuer_name),
|
||||
)
|
||||
|
||||
|
||||
@@ -331,6 +369,23 @@ def verify_totp_key(secret: str, code: str) -> bool:
|
||||
return pyotp.TOTP(secret).verify(code, valid_window=1)
|
||||
|
||||
|
||||
async def verify_totp_key_with_replay_protection(
|
||||
user_id: int, secret: str, code: str, redis: Redis
|
||||
) -> bool:
|
||||
"""验证TOTP密钥,并防止密钥重放攻击"""
|
||||
if not pyotp.TOTP(secret).verify(code, valid_window=1):
|
||||
return False
|
||||
|
||||
# 防止120秒内重复使用同一密钥(参考osu-web实现)
|
||||
cache_key = f"totp:{user_id}:{code}"
|
||||
if await redis.exists(cache_key):
|
||||
return False
|
||||
|
||||
# 设置120秒过期时间
|
||||
await redis.setex(cache_key, 120, "1")
|
||||
return True
|
||||
|
||||
|
||||
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)]
|
||||
|
||||
@@ -304,6 +304,16 @@ STORAGE_SETTINGS='{
|
||||
Field(default=None, description="TOTP 认证器中的发行者名称"),
|
||||
"验证服务设置",
|
||||
]
|
||||
totp_service_name: Annotated[
|
||||
str,
|
||||
Field(default="g0v0! Lazer Server", description="TOTP 认证器中显示的服务名称"),
|
||||
"验证服务设置",
|
||||
]
|
||||
totp_use_username_in_label: Annotated[
|
||||
bool,
|
||||
Field(default=True, description="在TOTP标签中使用用户名而不是邮箱"),
|
||||
"验证服务设置",
|
||||
]
|
||||
enable_email_verification: Annotated[
|
||||
bool,
|
||||
Field(default=False, description="是否启用邮件验证功能"),
|
||||
@@ -314,6 +324,11 @@ STORAGE_SETTINGS='{
|
||||
Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"),
|
||||
"验证服务设置",
|
||||
]
|
||||
enable_session_verification: Annotated[
|
||||
bool,
|
||||
Field(default=True, description="是否启用会话验证中间件"),
|
||||
"验证服务设置",
|
||||
]
|
||||
enable_multi_device_login: Annotated[
|
||||
bool,
|
||||
Field(default=True, description="是否允许多设备同时登录"),
|
||||
|
||||
@@ -49,5 +49,7 @@ class LoginSession(SQLModel, table=True):
|
||||
verified_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime = Field() # 会话过期时间
|
||||
is_new_location: bool = Field(default=False) # 是否新位置登录
|
||||
session_token: str | None = Field(default=None, max_length=64, index=True) # 会话令牌
|
||||
verification_method: str | None = Field(default=None, max_length=20) # 验证方法 (totp/mail)
|
||||
|
||||
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")
|
||||
|
||||
@@ -8,9 +8,11 @@ from app.database import User
|
||||
from app.database.auth import OAuthToken, V1APIKeys
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
|
||||
from .database import Database
|
||||
from .api_version import APIVersion
|
||||
from .database import Database, get_redis
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from redis.asyncio import Redis
|
||||
from fastapi.security import (
|
||||
APIKeyQuery,
|
||||
HTTPBearer,
|
||||
@@ -97,13 +99,40 @@ async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get
|
||||
return user_and_token[0]
|
||||
|
||||
|
||||
async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)):
|
||||
async def get_client_user(
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
api_version: APIVersion,
|
||||
user_and_token: UserAndToken = Depends(get_client_user_and_token)
|
||||
):
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
user, token = user_and_token
|
||||
|
||||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||||
raise HTTPException(status_code=403, detail="User not verified")
|
||||
# 获取当前验证方式
|
||||
verify_method = None
|
||||
if api_version >= 20250913:
|
||||
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
|
||||
|
||||
if verify_method is None:
|
||||
# 智能选择验证方式(有TOTP优先TOTP)
|
||||
totp_key = await user.awaitable_attrs.totp_key
|
||||
if totp_key is not None and api_version >= 20240101:
|
||||
verify_method = "totp"
|
||||
else:
|
||||
verify_method = "mail"
|
||||
|
||||
# 设置选择的验证方法到Redis中,避免重复选择
|
||||
if api_version >= 20250913:
|
||||
await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis)
|
||||
|
||||
# 返回符合 osu! API 标准的错误响应
|
||||
error_response = {
|
||||
"error": "User not verified",
|
||||
"method": verify_method
|
||||
}
|
||||
raise HTTPException(status_code=401, detail=error_response)
|
||||
return user
|
||||
|
||||
|
||||
|
||||
74
app/interfaces/session_verification.py
Normal file
74
app/interfaces/session_verification.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
会话验证接口
|
||||
|
||||
基于osu-web的SessionVerificationInterface实现
|
||||
用于标准化会话验证行为
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SessionVerificationInterface(ABC):
|
||||
"""会话验证接口
|
||||
|
||||
定义了会话验证所需的基本操作,参考osu-web的实现
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def find_for_verification(cls, session_id: str) -> Optional[SessionVerificationInterface]:
|
||||
"""根据会话ID查找会话用于验证
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
会话实例或None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥/ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_verification_method(self) -> Optional[str]:
|
||||
"""获取当前验证方法
|
||||
|
||||
Returns:
|
||||
验证方法 ('totp', 'mail') 或 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_verification_method(self, method: str) -> None:
|
||||
"""设置验证方法
|
||||
|
||||
Args:
|
||||
method: 验证方法 ('totp', 'mail')
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def user_id(self) -> Optional[int]:
|
||||
"""获取关联的用户ID"""
|
||||
pass
|
||||
9
app/middleware/__init__.py
Normal file
9
app/middleware/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
中间件模块
|
||||
|
||||
提供会话验证和其他中间件功能
|
||||
"""
|
||||
|
||||
from .verify_session import VerifySessionMiddleware, SessionState
|
||||
|
||||
__all__ = ["VerifySessionMiddleware", "SessionState"]
|
||||
318
app/middleware/session_verification.py
Normal file
318
app/middleware/session_verification.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
会话验证中间件和状态管理
|
||||
|
||||
基于osu-web的会话验证系统实现
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.database.lazer_user import User
|
||||
from app.database.verification import LoginSession
|
||||
from app.dependencies.database import with_db, get_redis
|
||||
from app.log import logger
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
|
||||
class SessionVerificationState:
|
||||
"""会话验证状态管理类
|
||||
|
||||
参考osu-web的State类实现
|
||||
"""
|
||||
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis):
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.redis = redis
|
||||
|
||||
@classmethod
|
||||
async def get_current(
|
||||
cls,
|
||||
request: Request,
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user: User,
|
||||
) -> Optional[SessionVerificationState]:
|
||||
"""获取当前会话验证状态"""
|
||||
try:
|
||||
# 从请求头或token中获取会话信息
|
||||
session_token = cls._extract_session_token(request)
|
||||
if not session_token:
|
||||
return None
|
||||
|
||||
# 查找会话
|
||||
session = await LoginSessionService.find_for_verification(db, session_token)
|
||||
if not session or session.user_id != user.id:
|
||||
return None
|
||||
|
||||
return cls(session, user, redis)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error getting current state: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_token(request: Request) -> Optional[str]:
|
||||
"""从请求中提取会话token"""
|
||||
# 尝试从Authorization header提取
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:] # 移除"Bearer "前缀
|
||||
|
||||
# 可以扩展其他提取方式
|
||||
return None
|
||||
|
||||
def get_method(self) -> str:
|
||||
"""获取验证方法
|
||||
|
||||
参考osu-web的逻辑,智能选择验证方法
|
||||
"""
|
||||
current_method = self.session.verification_method
|
||||
|
||||
if current_method is None:
|
||||
# 智能选择验证方法
|
||||
# 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证
|
||||
# 这里简化为检查用户是否有TOTP
|
||||
totp_key = getattr(self.user, 'totp_key', None)
|
||||
current_method = 'totp' if totp_key else 'mail'
|
||||
|
||||
# 设置验证方法
|
||||
asyncio.create_task(self._set_verification_method(current_method))
|
||||
|
||||
return current_method
|
||||
|
||||
async def _set_verification_method(self, method: str) -> None:
|
||||
"""内部方法:设置验证方法"""
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None and method in ['totp', 'mail']:
|
||||
# 类型检查确保method是正确的字面量类型
|
||||
verification_method = method if method in ['totp', 'mail'] else 'totp'
|
||||
await LoginSessionService.set_login_method(
|
||||
self.user.id, token_id, verification_method, self.redis # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error setting verification method: {e}")
|
||||
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
return self.session.is_verified
|
||||
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
try:
|
||||
# 创建专用数据库会话
|
||||
db = with_db()
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.mark_session_verified(
|
||||
db, self.redis, self.user.id, token_id
|
||||
)
|
||||
finally:
|
||||
await db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error marking session verified: {e}")
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥"""
|
||||
return str(self.session.id) if self.session.id else ""
|
||||
|
||||
def get_key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return LoginSessionService.get_key_for_event(self.get_key())
|
||||
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
|
||||
async def issue_mail_if_needed(self) -> None:
|
||||
"""如果需要,发送验证邮件"""
|
||||
try:
|
||||
if self.get_method() == "mail":
|
||||
from app.service.verification_service import EmailVerificationService
|
||||
|
||||
# 创建专用数据库会话发送邮件
|
||||
db = with_db()
|
||||
try:
|
||||
await EmailVerificationService.send_verification_email(
|
||||
db, self.redis, self.user.id, self.user.username,
|
||||
self.user.email, None, None
|
||||
)
|
||||
finally:
|
||||
await db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error issuing mail: {e}")
|
||||
|
||||
|
||||
class SessionVerificationController:
|
||||
"""会话验证控制器
|
||||
|
||||
参考osu-web的Controller类实现
|
||||
"""
|
||||
|
||||
# 需要跳过验证的路由(参考osu-web的SKIP_VERIFICATION_ROUTES)
|
||||
SKIP_VERIFICATION_ROUTES = {
|
||||
"/api/v2/session/verify",
|
||||
"/api/v2/session/verify/reissue",
|
||||
"/api/v2/me",
|
||||
"/api/v2/logout",
|
||||
"/oauth/token",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def should_skip_verification(request: Request) -> bool:
|
||||
"""检查是否应该跳过验证"""
|
||||
path = request.url.path
|
||||
return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES
|
||||
|
||||
@staticmethod
|
||||
async def initiate_verification(
|
||||
state: SessionVerificationState,
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""启动会话验证流程
|
||||
|
||||
参考osu-web的initiate方法
|
||||
"""
|
||||
try:
|
||||
method = state.get_method()
|
||||
|
||||
# 如果是邮件验证,发送验证邮件
|
||||
if method == "mail":
|
||||
await state.issue_mail_if_needed()
|
||||
|
||||
# API请求返回JSON响应
|
||||
if request.url.path.startswith("/api/"):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"method": method}
|
||||
)
|
||||
|
||||
# 其他情况可以扩展支持HTML响应
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"authentication": "verify",
|
||||
"method": method,
|
||||
"message": "Session verification required"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error initiating verification: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Verification initiation failed"
|
||||
)
|
||||
|
||||
|
||||
class SessionVerificationMiddleware:
|
||||
"""会话验证中间件
|
||||
|
||||
参考osu-web的VerifyUser中间件实现
|
||||
"""
|
||||
|
||||
def __init__(self, app: Callable[[Request], Awaitable[Response]]):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
"""中间件主要逻辑"""
|
||||
try:
|
||||
# 检查是否需要跳过验证
|
||||
if SessionVerificationController.should_skip_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取依赖项
|
||||
user = await self._get_user(request)
|
||||
if not user:
|
||||
# 未认证用户跳过验证
|
||||
return await call_next(request)
|
||||
|
||||
# 获取数据库和Redis连接
|
||||
db = await self._get_db()
|
||||
redis = await self._get_redis()
|
||||
|
||||
# 获取会话验证状态
|
||||
state = await SessionVerificationState.get_current(request, db, redis, user)
|
||||
if not state:
|
||||
# 无法获取会话状态,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否已验证
|
||||
if state.is_verified():
|
||||
# 已验证,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否需要验证
|
||||
if not self._requires_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 启动验证流程
|
||||
return await SessionVerificationController.initiate_verification(state, request)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
|
||||
# 出错时允许请求继续,避免阻塞正常流程
|
||||
return await call_next(request)
|
||||
|
||||
async def _get_user(self, request: Request) -> Optional[User]:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入
|
||||
# 简化实现,实际应该从token中解析用户
|
||||
return None # 暂时返回None,需要实际实现
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _get_db(self) -> AsyncSession:
|
||||
"""获取数据库连接"""
|
||||
return with_db()
|
||||
|
||||
async def _get_redis(self) -> Redis:
|
||||
"""获取Redis连接"""
|
||||
return get_redis()
|
||||
|
||||
def _requires_verification(self, request: Request) -> bool:
|
||||
"""检查是否需要验证
|
||||
|
||||
参考osu-web的requiresVerification方法
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# GET/HEAD/OPTIONS请求一般不需要验证
|
||||
safe_methods = {"GET", "HEAD", "OPTIONS"}
|
||||
if method in safe_methods:
|
||||
return False
|
||||
|
||||
# POST/PUT/DELETE等修改操作需要验证
|
||||
return True
|
||||
|
||||
|
||||
# FastAPI中间件包装器
|
||||
class FastAPISessionVerificationMiddleware:
|
||||
"""FastAPI会话验证中间件包装器"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
async def call_next(req: Request) -> Response:
|
||||
# 这里需要调用FastAPI应用
|
||||
return Response("OK") # 占位符实现
|
||||
|
||||
middleware = SessionVerificationMiddleware(call_next)
|
||||
response = await middleware(request, call_next)
|
||||
|
||||
await response(scope, receive, send)
|
||||
44
app/middleware/setup.py
Normal file
44
app/middleware/setup.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
中间件设置和配置
|
||||
|
||||
展示如何将会话验证中间件集成到FastAPI应用中
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.config import settings
|
||||
from app.middleware.verify_session import VerifySessionMiddleware
|
||||
|
||||
|
||||
def setup_session_verification_middleware(app: FastAPI) -> None:
|
||||
"""设置会话验证中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
# 只在启用会话验证时添加中间件
|
||||
if settings.enable_session_verification:
|
||||
app.add_middleware(VerifySessionMiddleware)
|
||||
|
||||
# 可以在这里添加中间件配置日志
|
||||
from app.log import logger
|
||||
logger.info("[Middleware] Session verification middleware enabled")
|
||||
else:
|
||||
from app.log import logger
|
||||
logger.info("[Middleware] Session verification middleware disabled")
|
||||
|
||||
|
||||
def setup_all_middlewares(app: FastAPI) -> None:
|
||||
"""设置所有中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI应用实例
|
||||
"""
|
||||
# 设置会话验证中间件
|
||||
setup_session_verification_middleware(app)
|
||||
|
||||
# 可以在这里添加其他中间件
|
||||
# app.add_middleware(OtherMiddleware)
|
||||
|
||||
from app.log import logger
|
||||
logger.info("[Middleware] All middlewares configured")
|
||||
287
app/middleware/verify_session.py
Normal file
287
app/middleware/verify_session.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
FastAPI会话验证中间件
|
||||
|
||||
基于osu-web的会话验证系统,适配FastAPI框架
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.database.lazer_user import User
|
||||
from app.database.verification import LoginSession
|
||||
from app.dependencies.database import with_db, get_redis
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.log import logger
|
||||
from app.service.verification_service import LoginSessionService
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
"""会话验证中间件
|
||||
|
||||
参考osu-web的VerifyUser中间件,适配FastAPI
|
||||
"""
|
||||
|
||||
# 需要跳过验证的路由
|
||||
SKIP_VERIFICATION_ROUTES = {
|
||||
"/api/v2/session/verify",
|
||||
"/api/v2/session/verify/reissue",
|
||||
"/api/v2/me",
|
||||
"/api/v2/logout",
|
||||
"/oauth/token",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/redoc",
|
||||
}
|
||||
|
||||
# 需要强制验证的路由模式(敏感操作)
|
||||
ALWAYS_VERIFY_PATTERNS = {
|
||||
"/api/v2/account/",
|
||||
"/api/v2/settings/",
|
||||
"/api/private/admin/",
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""中间件主处理逻辑"""
|
||||
try:
|
||||
# 检查是否跳过验证
|
||||
if self._should_skip_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取当前用户
|
||||
user = await self._get_current_user(request)
|
||||
if not user:
|
||||
# 未登录用户跳过验证
|
||||
return await call_next(request)
|
||||
|
||||
# 获取会话状态
|
||||
session_state = await self._get_session_state(request, user)
|
||||
if not session_state:
|
||||
# 无会话状态,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否已验证
|
||||
if session_state.is_verified():
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否需要验证
|
||||
if not self._requires_verification(request, user):
|
||||
return await call_next(request)
|
||||
|
||||
# 启动验证流程
|
||||
return await self._initiate_verification(request, session_state)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error: {e}")
|
||||
# 出错时允许请求继续,避免阻塞
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_verification(self, request: Request) -> bool:
|
||||
"""检查是否应该跳过验证"""
|
||||
path = request.url.path
|
||||
|
||||
# 完全匹配的跳过路由
|
||||
if path in self.SKIP_VERIFICATION_ROUTES:
|
||||
return True
|
||||
|
||||
# 非API请求跳过
|
||||
if not path.startswith("/api/"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _requires_verification(self, request: Request, user: User) -> bool:
|
||||
"""检查是否需要验证"""
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
# 检查是否为强制验证的路由
|
||||
for pattern in self.ALWAYS_VERIFY_PATTERNS:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
# 特权用户或非活跃用户需要验证
|
||||
if hasattr(user, 'is_privileged') and user.is_privileged():
|
||||
return True
|
||||
if hasattr(user, 'is_inactive') and user.is_inactive():
|
||||
return True
|
||||
|
||||
# 安全方法(GET/HEAD/OPTIONS)一般不需要验证
|
||||
safe_methods = {"GET", "HEAD", "OPTIONS"}
|
||||
if method in safe_methods:
|
||||
return False
|
||||
|
||||
# 修改操作(POST/PUT/DELETE/PATCH)需要验证
|
||||
return method in {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
async def _get_current_user(self, request: Request) -> Optional[User]:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 从Authorization header提取token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # 移除"Bearer "前缀
|
||||
|
||||
# 创建专用数据库会话
|
||||
db = with_db()
|
||||
try:
|
||||
# 获取token记录
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
# 获取用户
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
finally:
|
||||
await db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[Verify Session Middleware] Error getting user: {e}")
|
||||
return None
|
||||
|
||||
async def _get_session_state(self, request: Request, user: User) -> Optional[SessionState]:
|
||||
"""获取会话状态"""
|
||||
try:
|
||||
# 提取会话token(这里简化为使用相同的auth token)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
session_token = auth_header[7:]
|
||||
|
||||
# 获取数据库和Redis连接
|
||||
db = with_db()
|
||||
try:
|
||||
redis = get_redis()
|
||||
|
||||
# 查找会话
|
||||
session = await LoginSessionService.find_for_verification(db, session_token)
|
||||
if not session or session.user_id != user.id:
|
||||
return None
|
||||
|
||||
return SessionState(session, user, redis, db)
|
||||
finally:
|
||||
await db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
|
||||
return None
|
||||
|
||||
async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
|
||||
"""启动验证流程"""
|
||||
try:
|
||||
method = await state.get_method()
|
||||
|
||||
# 如果是邮件验证,可以在这里触发发送邮件
|
||||
if method == "mail":
|
||||
await state.issue_mail_if_needed()
|
||||
|
||||
# 返回验证要求响应
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"method": method,
|
||||
"message": "Session verification required"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error initiating verification: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"error": "Verification initiation failed"}
|
||||
)
|
||||
|
||||
|
||||
class SessionState:
|
||||
"""会话状态类
|
||||
|
||||
简化版本的会话状态管理
|
||||
"""
|
||||
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession):
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.redis = redis
|
||||
self.db = db
|
||||
self._verification_method: Optional[str] = None
|
||||
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
return self.session.is_verified
|
||||
|
||||
async def get_method(self) -> str:
|
||||
"""获取验证方法"""
|
||||
if self._verification_method is None:
|
||||
# 从Redis获取已设置的方法
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
self._verification_method = await LoginSessionService.get_login_method(
|
||||
self.user.id, token_id, self.redis
|
||||
)
|
||||
|
||||
# 如果没有设置,智能选择
|
||||
if self._verification_method is None:
|
||||
# 检查用户是否有TOTP密钥
|
||||
await self.user.awaitable_attrs.totp_key # 预加载
|
||||
totp_key = getattr(self.user, 'totp_key', None)
|
||||
self._verification_method = 'totp' if totp_key else 'mail'
|
||||
|
||||
# 保存选择的方法
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.set_login_method(
|
||||
self.user.id, token_id, self._verification_method, self.redis
|
||||
)
|
||||
|
||||
return self._verification_method
|
||||
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.mark_session_verified(
|
||||
self.db, self.redis, self.user.id, token_id
|
||||
)
|
||||
self.session.is_verified = True # 更新本地状态
|
||||
except Exception as e:
|
||||
logger.error(f"[Session State] Error marking verified: {e}")
|
||||
|
||||
async def issue_mail_if_needed(self) -> None:
|
||||
"""如果需要,发送验证邮件"""
|
||||
try:
|
||||
if await self.get_method() == "mail":
|
||||
from app.service.verification_service import EmailVerificationService
|
||||
|
||||
# 这里可以触发邮件发送
|
||||
await EmailVerificationService.send_verification_email(
|
||||
self.db, self.redis, self.user.id, self.user.username,
|
||||
self.user.email, None, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session State] Error issuing mail: {e}")
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥"""
|
||||
return str(self.session.id) if self.session.id else ""
|
||||
|
||||
def get_key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return LoginSessionService.get_key_for_event(self.get_key())
|
||||
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
@@ -48,9 +48,18 @@ class ChatServer:
|
||||
user_id = user.id
|
||||
if user_id in self.connect_client:
|
||||
del self.connect_client[user_id]
|
||||
|
||||
# 创建频道ID列表的副本以避免在迭代过程中修改字典
|
||||
channel_ids_to_process = []
|
||||
for channel_id, channel in self.channels.items():
|
||||
if user_id in channel:
|
||||
channel.remove(user_id)
|
||||
channel_ids_to_process.append(channel_id)
|
||||
|
||||
# 现在安全地处理每个频道
|
||||
for channel_id in channel_ids_to_process:
|
||||
# 再次检查用户是否仍在频道中(防止并发修改)
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
|
||||
@@ -5,9 +5,8 @@ from app.auth import (
|
||||
finish_create_totp_key,
|
||||
start_create_totp_key,
|
||||
totp_redis_key,
|
||||
verify_totp_key,
|
||||
verify_totp_key_with_replay_protection,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.lazer_user import User
|
||||
@@ -19,9 +18,38 @@ from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException, Security
|
||||
import pyotp
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class TotpStatusResp(BaseModel):
|
||||
"""TOTP状态响应"""
|
||||
enabled: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/totp/status",
|
||||
name="检查 TOTP 状态",
|
||||
description="检查当前用户是否已启用 TOTP 双因素验证",
|
||||
tags=["验证", "g0v0 API"],
|
||||
response_model=TotpStatusResp,
|
||||
)
|
||||
async def get_totp_status(
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
"""检查用户是否已创建TOTP"""
|
||||
totp_key = await current_user.awaitable_attrs.totp_key
|
||||
|
||||
if totp_key:
|
||||
return TotpStatusResp(
|
||||
enabled=True,
|
||||
created_at=totp_key.created_at.isoformat()
|
||||
)
|
||||
else:
|
||||
return TotpStatusResp(enabled=False)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/totp/create",
|
||||
name="开始 TOTP 创建流程",
|
||||
@@ -44,11 +72,16 @@ async def start_create_totp(
|
||||
|
||||
previous = await redis.hgetall(totp_redis_key(current_user)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
if previous: # pyright: ignore[reportGeneralTypeIssues]
|
||||
from app.auth import _generate_totp_account_label, _generate_totp_issuer_name
|
||||
|
||||
account_label = _generate_totp_account_label(current_user)
|
||||
issuer_name = _generate_totp_issuer_name()
|
||||
|
||||
return StartCreateTotpKeyResp(
|
||||
secret=previous["secret"],
|
||||
uri=pyotp.totp.TOTP(previous["secret"]).provisioning_uri(
|
||||
name=current_user.email,
|
||||
issuer_name=settings.totp_issuer,
|
||||
name=account_label,
|
||||
issuer_name=issuer_name,
|
||||
),
|
||||
)
|
||||
return await start_create_totp_key(current_user, redis)
|
||||
@@ -92,12 +125,21 @@ async def finish_create_totp(
|
||||
async def disable_totp(
|
||||
session: Database,
|
||||
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
totp = await session.get(TotpKeys, current_user.id)
|
||||
if not totp:
|
||||
raise HTTPException(status_code=400, detail="TOTP is not enabled for this user")
|
||||
if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)):
|
||||
|
||||
# 使用防重放保护的TOTP验证或备份码验证
|
||||
is_totp_valid = False
|
||||
if len(code) == 6 and code.isdigit():
|
||||
is_totp_valid = await verify_totp_key_with_replay_protection(current_user.id, totp.secret, code, redis)
|
||||
elif len(code) == BACKUP_CODE_LENGTH:
|
||||
is_totp_valid = check_totp_backup_code(totp, code)
|
||||
|
||||
if is_totp_valid:
|
||||
await session.delete(totp)
|
||||
await session.commit()
|
||||
else:
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.auth import check_totp_backup_code, verify_totp_key
|
||||
from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection
|
||||
from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.database.auth import TotpKeys
|
||||
@@ -40,7 +40,11 @@ class SessionReissueResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class VerifyFailed(Exception): ...
|
||||
class VerifyFailed(Exception):
|
||||
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
|
||||
super().__init__(message)
|
||||
self.reason = reason
|
||||
self.should_reissue = should_reissue
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -80,28 +84,42 @@ async def verify_session(
|
||||
try:
|
||||
totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key
|
||||
if verify_method is None:
|
||||
verify_method = "totp" if totp_key else "mail"
|
||||
# 智能选择验证方法(参考osu-web实现)
|
||||
# API版本较老或用户未设置TOTP时强制使用邮件验证
|
||||
#print(api_version, totp_key)
|
||||
if api_version < 20240101 or totp_key is None:
|
||||
verify_method = "mail"
|
||||
else:
|
||||
verify_method = "totp"
|
||||
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
|
||||
login_method = verify_method
|
||||
|
||||
if verify_method == "totp":
|
||||
if not totp_key:
|
||||
# TOTP密钥在验证开始和现在之间被删除(参考osu-web的fallback机制)
|
||||
if settings.enable_email_verification:
|
||||
await LoginSessionService.set_login_method(user_id, token_id, "mail", redis)
|
||||
await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
)
|
||||
verify_method = "mail"
|
||||
raise VerifyFailed("用户未设置 TOTP,已发送邮件验证码")
|
||||
raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证")
|
||||
# 如果未开启邮箱验证,则直接认为认证通过
|
||||
# 正常不会进入到这里
|
||||
|
||||
elif verify_totp_key(totp_key.secret, verification_key):
|
||||
elif await verify_totp_key_with_replay_protection(user_id, totp_key.secret, verification_key, redis):
|
||||
pass
|
||||
elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key):
|
||||
login_method = "totp_backup_code"
|
||||
else:
|
||||
raise VerifyFailed("TOTP 验证失败")
|
||||
# 记录详细的验证失败原因(参考osu-web的错误处理)
|
||||
if len(verification_key) != 6:
|
||||
raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
|
||||
elif not verification_key.isdigit():
|
||||
raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
|
||||
else:
|
||||
# 可能是密钥错误或者重放攻击
|
||||
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
|
||||
else:
|
||||
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
|
||||
if not success:
|
||||
@@ -127,7 +145,28 @@ async def verify_session(
|
||||
login_method=login_method,
|
||||
notes=str(e),
|
||||
)
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method})
|
||||
|
||||
# 构建更详细的错误响应(参考osu-web的错误处理)
|
||||
error_response = {
|
||||
"error": str(e),
|
||||
"method": verify_method,
|
||||
}
|
||||
|
||||
# 如果有具体的错误原因,添加到响应中
|
||||
if hasattr(e, 'reason') and e.reason:
|
||||
error_response["reason"] = e.reason
|
||||
|
||||
# 如果需要重新发送邮件验证码
|
||||
if hasattr(e, 'should_reissue') and e.should_reissue and verify_method == "mail":
|
||||
try:
|
||||
await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
)
|
||||
error_response["reissued"] = True
|
||||
except Exception:
|
||||
pass # 忽略重发邮件失败的错误
|
||||
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -74,10 +74,13 @@ class DatabaseCleanupScheduler:
|
||||
# 清理过期的登录会话
|
||||
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
# 清理1小时前未验证的登录会话
|
||||
unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
|
||||
|
||||
# 只在有清理记录时输出总结
|
||||
total_cleaned = expired_codes + expired_sessions
|
||||
total_cleaned = expired_codes + expired_sessions + unverified_sessions
|
||||
if total_cleaned > 0:
|
||||
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
|
||||
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during scheduled database cleanup: {e!s}")
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.utils import utcnow
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ class DatabaseCleanupService:
|
||||
# 查找指定天数前的已使用验证码记录
|
||||
cutoff_time = utcnow() - timedelta(days=days_old)
|
||||
|
||||
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
stmt = select(EmailVerification).where(EmailVerification.is_used == True)
|
||||
result = await db.exec(stmt)
|
||||
all_used_codes = result.all()
|
||||
|
||||
@@ -134,6 +134,50 @@ class DatabaseCleanupService:
|
||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_unverified_login_sessions(db: AsyncSession, hours_old: int = 1) -> int:
|
||||
"""
|
||||
清理指定小时前创建但仍未验证的登录会话
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
hours_old: 清理多少小时前创建但仍未验证的会话,默认1小时
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 计算截止时间
|
||||
cutoff_time = utcnow() - timedelta(hours=hours_old)
|
||||
|
||||
# 查找指定时间前创建且仍未验证的会话记录
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.created_at < cutoff_time
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
unverified_sessions = result.all()
|
||||
|
||||
# 删除未验证的会话记录
|
||||
deleted_count = 0
|
||||
for session in unverified_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)"
|
||||
)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning unverified login sessions: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||
"""
|
||||
@@ -150,7 +194,7 @@ class DatabaseCleanupService:
|
||||
# 查找指定天数前的已验证会话记录
|
||||
cutoff_time = utcnow() - timedelta(days=days_old)
|
||||
|
||||
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
stmt = select(LoginSession).where(LoginSession.is_verified == True)
|
||||
result = await db.exec(stmt)
|
||||
all_verified_sessions = result.all()
|
||||
|
||||
@@ -200,6 +244,9 @@ class DatabaseCleanupService:
|
||||
# 清理过期的登录会话
|
||||
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
# 清理1小时前未验证的登录会话
|
||||
results["unverified_login_sessions"] = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
|
||||
|
||||
# 清理7天前的已使用验证码
|
||||
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||
|
||||
@@ -227,6 +274,7 @@ class DatabaseCleanupService:
|
||||
"""
|
||||
try:
|
||||
current_time = utcnow()
|
||||
cutoff_1_hour = current_time - timedelta(hours=1)
|
||||
cutoff_7_days = current_time - timedelta(days=7)
|
||||
cutoff_30_days = current_time - timedelta(days=30)
|
||||
|
||||
@@ -240,8 +288,16 @@ class DatabaseCleanupService:
|
||||
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||
expired_sessions_count = len(expired_sessions_result.all())
|
||||
|
||||
# 统计1小时前未验证的登录会话数量
|
||||
unverified_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.created_at < cutoff_1_hour
|
||||
)
|
||||
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
|
||||
unverified_sessions_count = len(unverified_sessions_result.all())
|
||||
|
||||
# 统计7天前的已使用验证码数量
|
||||
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True)
|
||||
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
||||
all_used_codes = old_used_codes_result.all()
|
||||
old_used_codes_count = len(
|
||||
@@ -249,7 +305,7 @@ class DatabaseCleanupService:
|
||||
)
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True)
|
||||
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||
all_verified_sessions = old_verified_sessions_result.all()
|
||||
old_verified_sessions_count = len(
|
||||
@@ -263,10 +319,12 @@ class DatabaseCleanupService:
|
||||
return {
|
||||
"expired_verification_codes": expired_codes_count,
|
||||
"expired_login_sessions": expired_sessions_count,
|
||||
"unverified_login_sessions": unverified_sessions_count,
|
||||
"old_used_verification_codes": old_used_codes_count,
|
||||
"old_verified_sessions": old_verified_sessions_count,
|
||||
"total_cleanable": expired_codes_count
|
||||
+ expired_sessions_count
|
||||
+ unverified_sessions_count
|
||||
+ old_used_codes_count
|
||||
+ old_verified_sessions_count,
|
||||
}
|
||||
@@ -276,6 +334,7 @@ class DatabaseCleanupService:
|
||||
return {
|
||||
"expired_verification_codes": 0,
|
||||
"expired_login_sessions": 0,
|
||||
"unverified_login_sessions": 0,
|
||||
"old_used_verification_codes": 0,
|
||||
"old_verified_sessions": 0,
|
||||
"total_cleanable": 0,
|
||||
|
||||
@@ -330,6 +330,7 @@ class RankingCacheService:
|
||||
|
||||
# 计算统计信息
|
||||
stats = {
|
||||
"total": total_users,
|
||||
"total_users": total_users,
|
||||
"last_updated": utcnow().isoformat(),
|
||||
"type": type,
|
||||
|
||||
@@ -7,10 +7,11 @@ from __future__ import annotations
|
||||
from datetime import timedelta
|
||||
import secrets
|
||||
import string
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.interfaces.session_verification import SessionVerificationInterface
|
||||
from app.log import logger
|
||||
from app.service.client_detection_service import ClientDetectionService, ClientInfo
|
||||
from app.service.device_trust_service import DeviceTrustService
|
||||
@@ -514,6 +515,26 @@ This email was sent automatically, please do not reply.
|
||||
class LoginSessionService:
|
||||
"""登录会话服务"""
|
||||
|
||||
# Session verification interface methods
|
||||
@staticmethod
|
||||
async def find_for_verification(db: AsyncSession, session_id: str) -> Optional[LoginSession]:
|
||||
"""根据会话ID查找会话用于验证"""
|
||||
try:
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_id,
|
||||
LoginSession.expires_at > utcnow(),
|
||||
)
|
||||
)
|
||||
return result.first()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_key_for_event(session_id: str) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return f"g0v0:{session_id}"
|
||||
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
|
||||
6
main.py
6
main.py
@@ -46,6 +46,8 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
import sentry_sdk
|
||||
|
||||
from app.middleware.verify_session import VerifySessionMiddleware
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -171,6 +173,10 @@ app.include_router(lio_router)
|
||||
# from app.signalr import signalr_router
|
||||
# app.include_router(signalr_router)
|
||||
|
||||
# 会话验证中间件
|
||||
if settings.enable_session_verification:
|
||||
app.add_middleware(VerifySessionMiddleware)
|
||||
|
||||
# CORS 配置
|
||||
origins = []
|
||||
for url in [*settings.cors_urls, settings.server_url]:
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""feat(db): add session verification fields to login_session
|
||||
|
||||
Revision ID: 9419272e4c85
|
||||
Revises: fe8e9f3da298
|
||||
Create Date: 2025-09-24 00:46:57.367742
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9419272e4c85"
|
||||
down_revision: str | Sequence[str] | None = "fe8e9f3da298"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("login_sessions", sa.Column("session_token", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=True))
|
||||
op.add_column("login_sessions", sa.Column("verification_method", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
|
||||
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions")
|
||||
op.drop_column("login_sessions", "verification_method")
|
||||
op.drop_column("login_sessions", "session_token")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user