Merge pull request #46 from GooGuTeam/totp-fix

Totp fix
This commit is contained in:
咕谷酱
2025-09-24 03:19:23 +08:00
committed by GitHub
18 changed files with 1076 additions and 25 deletions

View File

@@ -317,13 +317,51 @@ def totp_redis_key(user: User) -> str:
return f"totp:setup:{user.email}"
def _generate_totp_account_label(user: User) -> str:
"""生成TOTP账户标签
根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性
"""
if settings.totp_use_username_in_label:
# 使用用户名作为主要标识
primary_identifier = user.username
else:
# 使用邮箱作为标识
primary_identifier = user.email
# 如果配置了服务名称,添加到标签中以便在认证器中区分
if settings.totp_service_name:
return f"{primary_identifier} ({settings.totp_service_name})"
else:
return primary_identifier
def _generate_totp_issuer_name() -> str:
"""生成TOTP发行者名称
优先使用自定义的totp_issuer否则使用服务名称
"""
if settings.totp_issuer:
return settings.totp_issuer
elif settings.totp_service_name:
return settings.totp_service_name
else:
# 回退到默认值
return "osu! Private Server"
async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp:
secret = pyotp.random_base32()
await redis.hset(totp_redis_key(user), mapping={"secret": secret, "fails": 0}) # pyright: ignore[reportGeneralTypeIssues]
await redis.expire(totp_redis_key(user), 300)
# 生成更完整的账户标签和issuer信息
account_label = _generate_totp_account_label(user)
issuer_name = _generate_totp_issuer_name()
return StartCreateTotpKeyResp(
secret=secret,
uri=pyotp.totp.TOTP(secret).provisioning_uri(name=user.email, issuer_name=settings.totp_issuer),
uri=pyotp.totp.TOTP(secret).provisioning_uri(name=account_label, issuer_name=issuer_name),
)
@@ -331,6 +369,23 @@ def verify_totp_key(secret: str, code: str) -> bool:
return pyotp.TOTP(secret).verify(code, valid_window=1)
async def verify_totp_key_with_replay_protection(
user_id: int, secret: str, code: str, redis: Redis
) -> bool:
"""验证TOTP密钥并防止密钥重放攻击"""
if not pyotp.TOTP(secret).verify(code, valid_window=1):
return False
# 防止120秒内重复使用同一密钥参考osu-web实现
cache_key = f"totp:{user_id}:{code}"
if await redis.exists(cache_key):
return False
# 设置120秒过期时间
await redis.setex(cache_key, 120, "1")
return True
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
alphabet = string.ascii_uppercase + string.digits
return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)]

View File

@@ -304,6 +304,16 @@ STORAGE_SETTINGS='{
Field(default=None, description="TOTP 认证器中的发行者名称"),
"验证服务设置",
]
totp_service_name: Annotated[
str,
Field(default="g0v0! Lazer Server", description="TOTP 认证器中显示的服务名称"),
"验证服务设置",
]
totp_use_username_in_label: Annotated[
bool,
Field(default=True, description="在TOTP标签中使用用户名而不是邮箱"),
"验证服务设置",
]
enable_email_verification: Annotated[
bool,
Field(default=False, description="是否启用邮件验证功能"),
@@ -314,6 +324,11 @@ STORAGE_SETTINGS='{
Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"),
"验证服务设置",
]
enable_session_verification: Annotated[
bool,
Field(default=True, description="是否启用会话验证中间件"),
"验证服务设置",
]
enable_multi_device_login: Annotated[
bool,
Field(default=True, description="是否允许多设备同时登录"),

View File

@@ -49,5 +49,7 @@ class LoginSession(SQLModel, table=True):
verified_at: datetime | None = Field(default=None)
expires_at: datetime = Field() # 会话过期时间
is_new_location: bool = Field(default=False) # 是否新位置登录
session_token: str | None = Field(default=None, max_length=64, index=True) # 会话令牌
verification_method: str | None = Field(default=None, max_length=20) # 验证方法 (totp/mail)
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")

View File

@@ -8,9 +8,11 @@ from app.database import User
from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import Database
from .api_version import APIVersion
from .database import Database, get_redis
from fastapi import Depends, HTTPException
from redis.asyncio import Redis
from fastapi.security import (
APIKeyQuery,
HTTPBearer,
@@ -97,13 +99,40 @@ async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get
return user_and_token[0]
async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)):
async def get_client_user(
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
api_version: APIVersion,
user_and_token: UserAndToken = Depends(get_client_user_and_token)
):
from app.service.verification_service import LoginSessionService
user, token = user_and_token
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
raise HTTPException(status_code=403, detail="User not verified")
# 获取当前验证方式
verify_method = None
if api_version >= 20250913:
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= 20240101:
verify_method = "totp"
else:
verify_method = "mail"
# 设置选择的验证方法到Redis中避免重复选择
if api_version >= 20250913:
await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis)
# 返回符合 osu! API 标准的错误响应
error_response = {
"error": "User not verified",
"method": verify_method
}
raise HTTPException(status_code=401, detail=error_response)
return user

View File

@@ -0,0 +1,74 @@
"""
会话验证接口
基于osu-web的SessionVerificationInterface实现
用于标准化会话验证行为
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
class SessionVerificationInterface(ABC):
"""会话验证接口
定义了会话验证所需的基本操作参考osu-web的实现
"""
@classmethod
@abstractmethod
async def find_for_verification(cls, session_id: str) -> Optional[SessionVerificationInterface]:
"""根据会话ID查找会话用于验证
Args:
session_id: 会话ID
Returns:
会话实例或None
"""
pass
@abstractmethod
def get_key(self) -> str:
"""获取会话密钥/ID"""
pass
@abstractmethod
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
pass
@abstractmethod
def get_verification_method(self) -> Optional[str]:
"""获取当前验证方法
Returns:
验证方法 ('totp', 'mail') 或 None
"""
pass
@abstractmethod
def is_verified(self) -> bool:
"""检查会话是否已验证"""
pass
@abstractmethod
async def mark_verified(self) -> None:
"""标记会话为已验证"""
pass
@abstractmethod
async def set_verification_method(self, method: str) -> None:
"""设置验证方法
Args:
method: 验证方法 ('totp', 'mail')
"""
pass
@abstractmethod
def user_id(self) -> Optional[int]:
"""获取关联的用户ID"""
pass

View File

@@ -0,0 +1,9 @@
"""
中间件模块
提供会话验证和其他中间件功能
"""
from .verify_session import VerifySessionMiddleware, SessionState
__all__ = ["VerifySessionMiddleware", "SessionState"]

View File

@@ -0,0 +1,318 @@
"""
会话验证中间件和状态管理
基于osu-web的会话验证系统实现
"""
from __future__ import annotations
import asyncio
from typing import Awaitable, Callable, Optional
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import with_db, get_redis
from app.log import logger
from app.service.verification_service import LoginSessionService
class SessionVerificationState:
"""会话验证状态管理类
参考osu-web的State类实现
"""
def __init__(self, session: LoginSession, user: User, redis: Redis):
self.session = session
self.user = user
self.redis = redis
@classmethod
async def get_current(
cls,
request: Request,
db: AsyncSession,
redis: Redis,
user: User,
) -> Optional[SessionVerificationState]:
"""获取当前会话验证状态"""
try:
# 从请求头或token中获取会话信息
session_token = cls._extract_session_token(request)
if not session_token:
return None
# 查找会话
session = await LoginSessionService.find_for_verification(db, session_token)
if not session or session.user_id != user.id:
return None
return cls(session, user, redis)
except Exception as e:
logger.error(f"[Session Verification] Error getting current state: {e}")
return None
@staticmethod
def _extract_session_token(request: Request) -> Optional[str]:
"""从请求中提取会话token"""
# 尝试从Authorization header提取
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:] # 移除"Bearer "前缀
# 可以扩展其他提取方式
return None
def get_method(self) -> str:
"""获取验证方法
参考osu-web的逻辑智能选择验证方法
"""
current_method = self.session.verification_method
if current_method is None:
# 智能选择验证方法
# 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证
# 这里简化为检查用户是否有TOTP
totp_key = getattr(self.user, 'totp_key', None)
current_method = 'totp' if totp_key else 'mail'
# 设置验证方法
asyncio.create_task(self._set_verification_method(current_method))
return current_method
async def _set_verification_method(self, method: str) -> None:
"""内部方法:设置验证方法"""
try:
token_id = self.session.token_id
if token_id is not None and method in ['totp', 'mail']:
# 类型检查确保method是正确的字面量类型
verification_method = method if method in ['totp', 'mail'] else 'totp'
await LoginSessionService.set_login_method(
self.user.id, token_id, verification_method, self.redis # type: ignore
)
except Exception as e:
logger.error(f"[Session Verification] Error setting verification method: {e}")
def is_verified(self) -> bool:
"""检查会话是否已验证"""
return self.session.is_verified
async def mark_verified(self) -> None:
"""标记会话为已验证"""
try:
# 创建专用数据库会话
db = with_db()
try:
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(
db, self.redis, self.user.id, token_id
)
finally:
await db.close()
except Exception as e:
logger.error(f"[Session Verification] Error marking session verified: {e}")
def get_key(self) -> str:
"""获取会话密钥"""
return str(self.session.id) if self.session.id else ""
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
return LoginSessionService.get_key_for_event(self.get_key())
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id
async def issue_mail_if_needed(self) -> None:
"""如果需要,发送验证邮件"""
try:
if self.get_method() == "mail":
from app.service.verification_service import EmailVerificationService
# 创建专用数据库会话发送邮件
db = with_db()
try:
await EmailVerificationService.send_verification_email(
db, self.redis, self.user.id, self.user.username,
self.user.email, None, None
)
finally:
await db.close()
except Exception as e:
logger.error(f"[Session Verification] Error issuing mail: {e}")
class SessionVerificationController:
"""会话验证控制器
参考osu-web的Controller类实现
"""
# 需要跳过验证的路由参考osu-web的SKIP_VERIFICATION_ROUTES
SKIP_VERIFICATION_ROUTES = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/me",
"/api/v2/logout",
"/oauth/token",
}
@staticmethod
def should_skip_verification(request: Request) -> bool:
"""检查是否应该跳过验证"""
path = request.url.path
return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES
@staticmethod
async def initiate_verification(
state: SessionVerificationState,
request: Request,
) -> Response:
"""启动会话验证流程
参考osu-web的initiate方法
"""
try:
method = state.get_method()
# 如果是邮件验证,发送验证邮件
if method == "mail":
await state.issue_mail_if_needed()
# API请求返回JSON响应
if request.url.path.startswith("/api/"):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"method": method}
)
# 其他情况可以扩展支持HTML响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"authentication": "verify",
"method": method,
"message": "Session verification required"
}
)
except Exception as e:
logger.error(f"[Session Verification] Error initiating verification: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Verification initiation failed"
)
class SessionVerificationMiddleware:
"""会话验证中间件
参考osu-web的VerifyUser中间件实现
"""
def __init__(self, app: Callable[[Request], Awaitable[Response]]):
self.app = app
async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
"""中间件主要逻辑"""
try:
# 检查是否需要跳过验证
if SessionVerificationController.should_skip_verification(request):
return await call_next(request)
# 获取依赖项
user = await self._get_user(request)
if not user:
# 未认证用户跳过验证
return await call_next(request)
# 获取数据库和Redis连接
db = await self._get_db()
redis = await self._get_redis()
# 获取会话验证状态
state = await SessionVerificationState.get_current(request, db, redis, user)
if not state:
# 无法获取会话状态,继续请求
return await call_next(request)
# 检查是否已验证
if state.is_verified():
# 已验证,继续请求
return await call_next(request)
# 检查是否需要验证
if not self._requires_verification(request):
return await call_next(request)
# 启动验证流程
return await SessionVerificationController.initiate_verification(state, request)
except Exception as e:
logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
# 出错时允许请求继续,避免阻塞正常流程
return await call_next(request)
async def _get_user(self, request: Request) -> Optional[User]:
"""获取当前用户"""
try:
# 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入
# 简化实现实际应该从token中解析用户
return None # 暂时返回None需要实际实现
except Exception:
return None
async def _get_db(self) -> AsyncSession:
"""获取数据库连接"""
return with_db()
async def _get_redis(self) -> Redis:
"""获取Redis连接"""
return get_redis()
def _requires_verification(self, request: Request) -> bool:
"""检查是否需要验证
参考osu-web的requiresVerification方法
"""
method = request.method
# GET/HEAD/OPTIONS请求一般不需要验证
safe_methods = {"GET", "HEAD", "OPTIONS"}
if method in safe_methods:
return False
# POST/PUT/DELETE等修改操作需要验证
return True
# FastAPI中间件包装器
class FastAPISessionVerificationMiddleware:
"""FastAPI会话验证中间件包装器"""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope, receive)
async def call_next(req: Request) -> Response:
# 这里需要调用FastAPI应用
return Response("OK") # 占位符实现
middleware = SessionVerificationMiddleware(call_next)
response = await middleware(request, call_next)
await response(scope, receive, send)

44
app/middleware/setup.py Normal file
View File

@@ -0,0 +1,44 @@
"""
中间件设置和配置
展示如何将会话验证中间件集成到FastAPI应用中
"""
from fastapi import FastAPI
from app.config import settings
from app.middleware.verify_session import VerifySessionMiddleware
def setup_session_verification_middleware(app: FastAPI) -> None:
"""设置会话验证中间件
Args:
app: FastAPI应用实例
"""
# 只在启用会话验证时添加中间件
if settings.enable_session_verification:
app.add_middleware(VerifySessionMiddleware)
# 可以在这里添加中间件配置日志
from app.log import logger
logger.info("[Middleware] Session verification middleware enabled")
else:
from app.log import logger
logger.info("[Middleware] Session verification middleware disabled")
def setup_all_middlewares(app: FastAPI) -> None:
"""设置所有中间件
Args:
app: FastAPI应用实例
"""
# 设置会话验证中间件
setup_session_verification_middleware(app)
# 可以在这里添加其他中间件
# app.add_middleware(OtherMiddleware)
from app.log import logger
logger.info("[Middleware] All middlewares configured")

View File

@@ -0,0 +1,287 @@
"""
FastAPI会话验证中间件
基于osu-web的会话验证系统适配FastAPI框架
"""
from __future__ import annotations
from typing import Callable, Optional
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import with_db, get_redis
from app.auth import get_token_by_access_token
from app.log import logger
from app.service.verification_service import LoginSessionService
from sqlmodel import select
class VerifySessionMiddleware(BaseHTTPMiddleware):
"""会话验证中间件
参考osu-web的VerifyUser中间件适配FastAPI
"""
# 需要跳过验证的路由
SKIP_VERIFICATION_ROUTES = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/me",
"/api/v2/logout",
"/oauth/token",
"/health",
"/metrics",
"/docs",
"/openapi.json",
"/redoc",
}
# 需要强制验证的路由模式(敏感操作)
ALWAYS_VERIFY_PATTERNS = {
"/api/v2/account/",
"/api/v2/settings/",
"/api/private/admin/",
}
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""中间件主处理逻辑"""
try:
# 检查是否跳过验证
if self._should_skip_verification(request):
return await call_next(request)
# 获取当前用户
user = await self._get_current_user(request)
if not user:
# 未登录用户跳过验证
return await call_next(request)
# 获取会话状态
session_state = await self._get_session_state(request, user)
if not session_state:
# 无会话状态,继续请求
return await call_next(request)
# 检查是否已验证
if session_state.is_verified():
return await call_next(request)
# 检查是否需要验证
if not self._requires_verification(request, user):
return await call_next(request)
# 启动验证流程
return await self._initiate_verification(request, session_state)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error: {e}")
# 出错时允许请求继续,避免阻塞
return await call_next(request)
def _should_skip_verification(self, request: Request) -> bool:
"""检查是否应该跳过验证"""
path = request.url.path
# 完全匹配的跳过路由
if path in self.SKIP_VERIFICATION_ROUTES:
return True
# 非API请求跳过
if not path.startswith("/api/"):
return True
return False
def _requires_verification(self, request: Request, user: User) -> bool:
"""检查是否需要验证"""
path = request.url.path
method = request.method
# 检查是否为强制验证的路由
for pattern in self.ALWAYS_VERIFY_PATTERNS:
if path.startswith(pattern):
return True
# 特权用户或非活跃用户需要验证
if hasattr(user, 'is_privileged') and user.is_privileged():
return True
if hasattr(user, 'is_inactive') and user.is_inactive():
return True
# 安全方法GET/HEAD/OPTIONS一般不需要验证
safe_methods = {"GET", "HEAD", "OPTIONS"}
if method in safe_methods:
return False
# 修改操作POST/PUT/DELETE/PATCH需要验证
return method in {"POST", "PUT", "DELETE", "PATCH"}
async def _get_current_user(self, request: Request) -> Optional[User]:
"""获取当前用户"""
try:
# 从Authorization header提取token
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
token = auth_header[7:] # 移除"Bearer "前缀
# 创建专用数据库会话
db = with_db()
try:
# 获取token记录
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
# 获取用户
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user
finally:
await db.close()
except Exception as e:
logger.debug(f"[Verify Session Middleware] Error getting user: {e}")
return None
async def _get_session_state(self, request: Request, user: User) -> Optional[SessionState]:
"""获取会话状态"""
try:
# 提取会话token这里简化为使用相同的auth token
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return None
session_token = auth_header[7:]
# 获取数据库和Redis连接
db = with_db()
try:
redis = get_redis()
# 查找会话
session = await LoginSessionService.find_for_verification(db, session_token)
if not session or session.user_id != user.id:
return None
return SessionState(session, user, redis, db)
finally:
await db.close()
except Exception as e:
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
return None
async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
"""启动验证流程"""
try:
method = await state.get_method()
# 如果是邮件验证,可以在这里触发发送邮件
if method == "mail":
await state.issue_mail_if_needed()
# 返回验证要求响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"method": method,
"message": "Session verification required"
}
)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error initiating verification: {e}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "Verification initiation failed"}
)
class SessionState:
"""会话状态类
简化版本的会话状态管理
"""
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession):
self.session = session
self.user = user
self.redis = redis
self.db = db
self._verification_method: Optional[str] = None
def is_verified(self) -> bool:
"""检查会话是否已验证"""
return self.session.is_verified
async def get_method(self) -> str:
"""获取验证方法"""
if self._verification_method is None:
# 从Redis获取已设置的方法
token_id = self.session.token_id
if token_id is not None:
self._verification_method = await LoginSessionService.get_login_method(
self.user.id, token_id, self.redis
)
# 如果没有设置,智能选择
if self._verification_method is None:
# 检查用户是否有TOTP密钥
await self.user.awaitable_attrs.totp_key # 预加载
totp_key = getattr(self.user, 'totp_key', None)
self._verification_method = 'totp' if totp_key else 'mail'
# 保存选择的方法
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.set_login_method(
self.user.id, token_id, self._verification_method, self.redis
)
return self._verification_method
async def mark_verified(self) -> None:
"""标记会话为已验证"""
try:
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(
self.db, self.redis, self.user.id, token_id
)
self.session.is_verified = True # 更新本地状态
except Exception as e:
logger.error(f"[Session State] Error marking verified: {e}")
async def issue_mail_if_needed(self) -> None:
"""如果需要,发送验证邮件"""
try:
if await self.get_method() == "mail":
from app.service.verification_service import EmailVerificationService
# 这里可以触发邮件发送
await EmailVerificationService.send_verification_email(
self.db, self.redis, self.user.id, self.user.username,
self.user.email, None, None
)
except Exception as e:
logger.error(f"[Session State] Error issuing mail: {e}")
def get_key(self) -> str:
"""获取会话密钥"""
return str(self.session.id) if self.session.id else ""
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
return LoginSessionService.get_key_for_event(self.get_key())
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id

View File

@@ -48,9 +48,18 @@ class ChatServer:
user_id = user.id
if user_id in self.connect_client:
del self.connect_client[user_id]
# 创建频道ID列表的副本以避免在迭代过程中修改字典
channel_ids_to_process = []
for channel_id, channel in self.channels.items():
if user_id in channel:
channel.remove(user_id)
channel_ids_to_process.append(channel_id)
# 现在安全地处理每个频道
for channel_id in channel_ids_to_process:
# 再次检查用户是否仍在频道中(防止并发修改)
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))

View File

@@ -5,9 +5,8 @@ from app.auth import (
finish_create_totp_key,
start_create_totp_key,
totp_redis_key,
verify_totp_key,
verify_totp_key_with_replay_protection,
)
from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys
from app.database.lazer_user import User
@@ -19,9 +18,38 @@ from .router import router
from fastapi import Body, Depends, HTTPException, Security
import pyotp
from pydantic import BaseModel
from redis.asyncio import Redis
class TotpStatusResp(BaseModel):
"""TOTP状态响应"""
enabled: bool
created_at: str | None = None
@router.get(
"/totp/status",
name="检查 TOTP 状态",
description="检查当前用户是否已启用 TOTP 双因素验证",
tags=["验证", "g0v0 API"],
response_model=TotpStatusResp,
)
async def get_totp_status(
current_user: User = Security(get_client_user),
):
"""检查用户是否已创建TOTP"""
totp_key = await current_user.awaitable_attrs.totp_key
if totp_key:
return TotpStatusResp(
enabled=True,
created_at=totp_key.created_at.isoformat()
)
else:
return TotpStatusResp(enabled=False)
@router.post(
"/totp/create",
name="开始 TOTP 创建流程",
@@ -44,11 +72,16 @@ async def start_create_totp(
previous = await redis.hgetall(totp_redis_key(current_user)) # pyright: ignore[reportGeneralTypeIssues]
if previous: # pyright: ignore[reportGeneralTypeIssues]
from app.auth import _generate_totp_account_label, _generate_totp_issuer_name
account_label = _generate_totp_account_label(current_user)
issuer_name = _generate_totp_issuer_name()
return StartCreateTotpKeyResp(
secret=previous["secret"],
uri=pyotp.totp.TOTP(previous["secret"]).provisioning_uri(
name=current_user.email,
issuer_name=settings.totp_issuer,
name=account_label,
issuer_name=issuer_name,
),
)
return await start_create_totp_key(current_user, redis)
@@ -92,12 +125,21 @@ async def finish_create_totp(
async def disable_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
):
totp = await session.get(TotpKeys, current_user.id)
if not totp:
raise HTTPException(status_code=400, detail="TOTP is not enabled for this user")
if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)):
# 使用防重放保护的TOTP验证或备份码验证
is_totp_valid = False
if len(code) == 6 and code.isdigit():
is_totp_valid = await verify_totp_key_with_replay_protection(current_user.id, totp.secret, code, redis)
elif len(code) == BACKUP_CODE_LENGTH:
is_totp_valid = check_totp_backup_code(totp, code)
if is_totp_valid:
await session.delete(totp)
await session.commit()
else:

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
from typing import Annotated, Literal
from app.auth import check_totp_backup_code, verify_totp_key
from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection
from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys
@@ -40,7 +40,11 @@ class SessionReissueResponse(BaseModel):
message: str
class VerifyFailed(Exception): ...
class VerifyFailed(Exception):
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
super().__init__(message)
self.reason = reason
self.should_reissue = should_reissue
@router.post(
@@ -80,28 +84,42 @@ async def verify_session(
try:
totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key
if verify_method is None:
verify_method = "totp" if totp_key else "mail"
# 智能选择验证方法参考osu-web实现
# API版本较老或用户未设置TOTP时强制使用邮件验证
#print(api_version, totp_key)
if api_version < 20240101 or totp_key is None:
verify_method = "mail"
else:
verify_method = "totp"
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
login_method = verify_method
if verify_method == "totp":
if not totp_key:
# TOTP密钥在验证开始和现在之间被删除参考osu-web的fallback机制
if settings.enable_email_verification:
await LoginSessionService.set_login_method(user_id, token_id, "mail", redis)
await EmailVerificationService.send_verification_email(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
)
verify_method = "mail"
raise VerifyFailed("用户未设置 TOTP已发送邮件验证")
raise VerifyFailed("用户TOTP已被删除已切换到邮件验证")
# 如果未开启邮箱验证,则直接认为认证通过
# 正常不会进入到这里
elif verify_totp_key(totp_key.secret, verification_key):
elif await verify_totp_key_with_replay_protection(user_id, totp_key.secret, verification_key, redis):
pass
elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key):
login_method = "totp_backup_code"
else:
raise VerifyFailed("TOTP 验证失败")
# 记录详细的验证失败原因参考osu-web的错误处理
if len(verification_key) != 6:
raise VerifyFailed("TOTP验证码长度错误应为6位数字", reason="incorrect_length")
elif not verification_key.isdigit():
raise VerifyFailed("TOTP验证码格式错误应为纯数字", reason="incorrect_format")
else:
# 可能是密钥错误或者重放攻击
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
else:
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
if not success:
@@ -127,7 +145,28 @@ async def verify_session(
login_method=login_method,
notes=str(e),
)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method})
# 构建更详细的错误响应参考osu-web的错误处理
error_response = {
"error": str(e),
"method": verify_method,
}
# 如果有具体的错误原因,添加到响应中
if hasattr(e, 'reason') and e.reason:
error_response["reason"] = e.reason
# 如果需要重新发送邮件验证码
if hasattr(e, 'should_reissue') and e.should_reissue and verify_method == "mail":
try:
await EmailVerificationService.send_verification_email(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
)
error_response["reissued"] = True
except Exception:
pass # 忽略重发邮件失败的错误
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
@router.post(

View File

@@ -74,10 +74,13 @@ class DatabaseCleanupScheduler:
# 清理过期的登录会话
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理1小时前未验证的登录会话
unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
# 只在有清理记录时输出总结
total_cleaned = expired_codes + expired_sessions
total_cleaned = expired_codes + expired_sessions + unverified_sessions
if total_cleaned > 0:
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}")
except Exception as e:
logger.error(f"Error during scheduled database cleanup: {e!s}")

View File

@@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession
from app.log import logger
from app.utils import utcnow
from sqlmodel import col, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -107,7 +107,7 @@ class DatabaseCleanupService:
# 查找指定天数前的已使用验证码记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
stmt = select(EmailVerification).where(EmailVerification.is_used == True)
result = await db.exec(stmt)
all_used_codes = result.all()
@@ -134,6 +134,50 @@ class DatabaseCleanupService:
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
return 0
@staticmethod
async def cleanup_unverified_login_sessions(db: AsyncSession, hours_old: int = 1) -> int:
"""
清理指定小时前创建但仍未验证的登录会话
Args:
db: 数据库会话
hours_old: 清理多少小时前创建但仍未验证的会话默认1小时
Returns:
int: 清理的记录数
"""
try:
# 计算截止时间
cutoff_time = utcnow() - timedelta(hours=hours_old)
# 查找指定时间前创建且仍未验证的会话记录
stmt = select(LoginSession).where(
LoginSession.is_verified == False,
LoginSession.created_at < cutoff_time
)
result = await db.exec(stmt)
unverified_sessions = result.all()
# 删除未验证的会话记录
deleted_count = 0
for session in unverified_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)"
)
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning unverified login sessions: {e!s}")
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
"""
@@ -150,7 +194,7 @@ class DatabaseCleanupService:
# 查找指定天数前的已验证会话记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
stmt = select(LoginSession).where(LoginSession.is_verified == True)
result = await db.exec(stmt)
all_verified_sessions = result.all()
@@ -200,6 +244,9 @@ class DatabaseCleanupService:
# 清理过期的登录会话
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理1小时前未验证的登录会话
results["unverified_login_sessions"] = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
# 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
@@ -227,6 +274,7 @@ class DatabaseCleanupService:
"""
try:
current_time = utcnow()
cutoff_1_hour = current_time - timedelta(hours=1)
cutoff_7_days = current_time - timedelta(days=7)
cutoff_30_days = current_time - timedelta(days=30)
@@ -240,8 +288,16 @@ class DatabaseCleanupService:
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
# 统计1小时前未验证的登录会话数量
unverified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == False,
LoginSession.created_at < cutoff_1_hour
)
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
unverified_sessions_count = len(unverified_sessions_result.all())
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True)
old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all()
old_used_codes_count = len(
@@ -249,7 +305,7 @@ class DatabaseCleanupService:
)
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True)
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len(
@@ -263,10 +319,12 @@ class DatabaseCleanupService:
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"unverified_login_sessions": unverified_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"total_cleanable": expired_codes_count
+ expired_sessions_count
+ unverified_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
}
@@ -276,6 +334,7 @@ class DatabaseCleanupService:
return {
"expired_verification_codes": 0,
"expired_login_sessions": 0,
"unverified_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"total_cleanable": 0,

View File

@@ -330,6 +330,7 @@ class RankingCacheService:
# 计算统计信息
stats = {
"total": total_users,
"total_users": total_users,
"last_updated": utcnow().isoformat(),
"type": type,

View File

@@ -7,10 +7,11 @@ from __future__ import annotations
from datetime import timedelta
import secrets
import string
from typing import Literal
from typing import Literal, Optional
from app.config import settings
from app.database.verification import EmailVerification, LoginSession
from app.interfaces.session_verification import SessionVerificationInterface
from app.log import logger
from app.service.client_detection_service import ClientDetectionService, ClientInfo
from app.service.device_trust_service import DeviceTrustService
@@ -514,6 +515,26 @@ This email was sent automatically, please do not reply.
class LoginSessionService:
"""登录会话服务"""
# Session verification interface methods
@staticmethod
async def find_for_verification(db: AsyncSession, session_id: str) -> Optional[LoginSession]:
"""根据会话ID查找会话用于验证"""
try:
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_id,
LoginSession.expires_at > utcnow(),
)
)
return result.first()
except Exception:
return None
@staticmethod
def get_key_for_event(session_id: str) -> str:
"""获取用于事件广播的会话密钥"""
return f"g0v0:{session_id}"
@staticmethod
async def create_session(
db: AsyncSession,

View File

@@ -46,6 +46,8 @@ from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_limiter import FastAPILimiter
import sentry_sdk
from app.middleware.verify_session import VerifySessionMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -171,6 +173,10 @@ app.include_router(lio_router)
# from app.signalr import signalr_router
# app.include_router(signalr_router)
# 会话验证中间件
if settings.enable_session_verification:
app.add_middleware(VerifySessionMiddleware)
# CORS 配置
origins = []
for url in [*settings.cors_urls, settings.server_url]:

View File

@@ -0,0 +1,38 @@
"""feat(db): add session verification fields to login_session
Revision ID: 9419272e4c85
Revises: fe8e9f3da298
Create Date: 2025-09-24 00:46:57.367742
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "9419272e4c85"
down_revision: str | Sequence[str] | None = "fe8e9f3da298"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("login_sessions", sa.Column("session_token", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=True))
op.add_column("login_sessions", sa.Column("verification_method", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions")
op.drop_column("login_sessions", "verification_method")
op.drop_column("login_sessions", "session_token")
# ### end Alembic commands ###