添加防止重放攻击
This commit is contained in:
17
app/auth.py
17
app/auth.py
@@ -331,6 +331,23 @@ def verify_totp_key(secret: str, code: str) -> bool:
|
|||||||
return pyotp.TOTP(secret).verify(code, valid_window=1)
|
return pyotp.TOTP(secret).verify(code, valid_window=1)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_totp_key_with_replay_protection(
|
||||||
|
user_id: int, secret: str, code: str, redis: Redis
|
||||||
|
) -> bool:
|
||||||
|
"""验证TOTP密钥,并防止密钥重放攻击"""
|
||||||
|
if not pyotp.TOTP(secret).verify(code, valid_window=1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 防止120秒内重复使用同一密钥(参考osu-web实现)
|
||||||
|
cache_key = f"totp:{user_id}:{code}"
|
||||||
|
if await redis.exists(cache_key):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 设置120秒过期时间
|
||||||
|
await redis.setex(cache_key, 120, "1")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
|
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
|
||||||
alphabet = string.ascii_uppercase + string.digits
|
alphabet = string.ascii_uppercase + string.digits
|
||||||
return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)]
|
return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from app.auth import (
|
|||||||
finish_create_totp_key,
|
finish_create_totp_key,
|
||||||
start_create_totp_key,
|
start_create_totp_key,
|
||||||
totp_redis_key,
|
totp_redis_key,
|
||||||
verify_totp_key,
|
verify_totp_key_with_replay_protection,
|
||||||
)
|
)
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.const import BACKUP_CODE_LENGTH
|
from app.const import BACKUP_CODE_LENGTH
|
||||||
@@ -92,12 +92,21 @@ async def finish_create_totp(
|
|||||||
async def disable_totp(
|
async def disable_totp(
|
||||||
session: Database,
|
session: Database,
|
||||||
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
|
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
|
||||||
|
redis: Redis = Depends(get_redis),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
):
|
):
|
||||||
totp = await session.get(TotpKeys, current_user.id)
|
totp = await session.get(TotpKeys, current_user.id)
|
||||||
if not totp:
|
if not totp:
|
||||||
raise HTTPException(status_code=400, detail="TOTP is not enabled for this user")
|
raise HTTPException(status_code=400, detail="TOTP is not enabled for this user")
|
||||||
if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)):
|
|
||||||
|
# 使用防重放保护的TOTP验证或备份码验证
|
||||||
|
is_totp_valid = False
|
||||||
|
if len(code) == 6 and code.isdigit():
|
||||||
|
is_totp_valid = await verify_totp_key_with_replay_protection(current_user.id, totp.secret, code, redis)
|
||||||
|
elif len(code) == BACKUP_CODE_LENGTH:
|
||||||
|
is_totp_valid = check_totp_backup_code(totp, code)
|
||||||
|
|
||||||
|
if is_totp_valid:
|
||||||
await session.delete(totp)
|
await session.delete(totp)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from app.auth import check_totp_backup_code, verify_totp_key
|
from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.const import BACKUP_CODE_LENGTH
|
from app.const import BACKUP_CODE_LENGTH
|
||||||
from app.database.auth import TotpKeys
|
from app.database.auth import TotpKeys
|
||||||
@@ -40,7 +40,11 @@ class SessionReissueResponse(BaseModel):
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
class VerifyFailed(Exception): ...
|
class VerifyFailed(Exception):
|
||||||
|
def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
|
||||||
|
super().__init__(message)
|
||||||
|
self.reason = reason
|
||||||
|
self.should_reissue = should_reissue
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -80,28 +84,41 @@ async def verify_session(
|
|||||||
try:
|
try:
|
||||||
totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key
|
totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key
|
||||||
if verify_method is None:
|
if verify_method is None:
|
||||||
verify_method = "totp" if totp_key else "mail"
|
# 智能选择验证方法(参考osu-web实现)
|
||||||
|
# API版本较老或用户未设置TOTP时强制使用邮件验证
|
||||||
|
if api_version < 20240101 or totp_key is None:
|
||||||
|
verify_method = "mail"
|
||||||
|
else:
|
||||||
|
verify_method = "totp"
|
||||||
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
|
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
|
||||||
login_method = verify_method
|
login_method = verify_method
|
||||||
|
|
||||||
if verify_method == "totp":
|
if verify_method == "totp":
|
||||||
if not totp_key:
|
if not totp_key:
|
||||||
|
# TOTP密钥在验证开始和现在之间被删除(参考osu-web的fallback机制)
|
||||||
if settings.enable_email_verification:
|
if settings.enable_email_verification:
|
||||||
await LoginSessionService.set_login_method(user_id, token_id, "mail", redis)
|
await LoginSessionService.set_login_method(user_id, token_id, "mail", redis)
|
||||||
await EmailVerificationService.send_verification_email(
|
await EmailVerificationService.send_verification_email(
|
||||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||||
)
|
)
|
||||||
verify_method = "mail"
|
verify_method = "mail"
|
||||||
raise VerifyFailed("用户未设置 TOTP,已发送邮件验证码")
|
raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证")
|
||||||
# 如果未开启邮箱验证,则直接认为认证通过
|
# 如果未开启邮箱验证,则直接认为认证通过
|
||||||
# 正常不会进入到这里
|
# 正常不会进入到这里
|
||||||
|
|
||||||
elif verify_totp_key(totp_key.secret, verification_key):
|
elif await verify_totp_key_with_replay_protection(user_id, totp_key.secret, verification_key, redis):
|
||||||
pass
|
pass
|
||||||
elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key):
|
elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key):
|
||||||
login_method = "totp_backup_code"
|
login_method = "totp_backup_code"
|
||||||
else:
|
else:
|
||||||
raise VerifyFailed("TOTP 验证失败")
|
# 记录详细的验证失败原因(参考osu-web的错误处理)
|
||||||
|
if len(verification_key) != 6:
|
||||||
|
raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
|
||||||
|
elif not verification_key.isdigit():
|
||||||
|
raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
|
||||||
|
else:
|
||||||
|
# 可能是密钥错误或者重放攻击
|
||||||
|
raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
|
||||||
else:
|
else:
|
||||||
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
|
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
|
||||||
if not success:
|
if not success:
|
||||||
@@ -127,7 +144,28 @@ async def verify_session(
|
|||||||
login_method=login_method,
|
login_method=login_method,
|
||||||
notes=str(e),
|
notes=str(e),
|
||||||
)
|
)
|
||||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method})
|
|
||||||
|
# 构建更详细的错误响应(参考osu-web的错误处理)
|
||||||
|
error_response = {
|
||||||
|
"method": verify_method,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有具体的错误原因,添加到响应中
|
||||||
|
if hasattr(e, 'reason') and e.reason:
|
||||||
|
error_response["reason"] = e.reason
|
||||||
|
|
||||||
|
# 如果需要重新发送邮件验证码
|
||||||
|
if hasattr(e, 'should_reissue') and e.should_reissue and verify_method == "mail":
|
||||||
|
try:
|
||||||
|
await EmailVerificationService.send_verification_email(
|
||||||
|
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||||
|
)
|
||||||
|
error_response["reissued"] = True
|
||||||
|
except Exception:
|
||||||
|
pass # 忽略重发邮件失败的错误
|
||||||
|
|
||||||
|
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|||||||
Reference in New Issue
Block a user