feat(session-verify): 添加 TOTP 支持 (#34)
* chore(deps): add pyotp * feat(auth): implement TOTP verification feat(auth): implement TOTP verification and email verification services - Added TOTP keys management with a new database model `TotpKeys`. - Introduced `EmailVerification` and `LoginSession` models for email verification. - Created `verification_service` to handle email verification logic and TOTP processes. - Updated user response models to include session verification methods. - Implemented routes for TOTP creation, verification, and fallback to email verification. - Enhanced login session management to support new location checks and verification methods. - Added migration script to create `totp_keys` table in the database. * feat(config): update config example * docs(totp): complete creating TOTP flow * refactor(totp): resolve review * feat(api): forbid unverified request * fix(totp): trace session by token id to avoid other sessions are forbidden * chore(linter): make pyright happy * fix(totp): only mark sessions with a specified token id
This commit is contained in:
@@ -42,7 +42,9 @@ FETCHER_SCOPES="public"
|
||||
# Logging Settings
|
||||
LOG_LEVEL="INFO"
|
||||
|
||||
# Email Service Settings
|
||||
# Verification Settings
|
||||
ENABLE_TOTP_VERIFICATION=true
|
||||
TOTP_ISSUER="osu! server"
|
||||
ENABLE_EMAIL_VERIFICATION=false
|
||||
SMTP_SERVER="localhost"
|
||||
SMTP_PORT=587
|
||||
|
||||
77
app/auth.py
77
app/auth.py
@@ -7,16 +7,20 @@ import secrets
|
||||
import string
|
||||
|
||||
from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.database import (
|
||||
OAuthToken,
|
||||
User,
|
||||
)
|
||||
from app.database.auth import TotpKeys
|
||||
from app.log import logger
|
||||
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
|
||||
from app.utils import utcnow
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
import pyotp
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -277,3 +281,76 @@ async def get_user_by_authorization_code(
|
||||
await db.refresh(user)
|
||||
return (user, scopes.split(","))
|
||||
return None
|
||||
|
||||
|
||||
def totp_redis_key(user: User) -> str:
|
||||
return f"totp:setup:{user.email}"
|
||||
|
||||
|
||||
async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp:
|
||||
secret = pyotp.random_base32()
|
||||
await redis.hset(totp_redis_key(user), mapping={"secret": secret, "fails": 0}) # pyright: ignore[reportGeneralTypeIssues]
|
||||
await redis.expire(totp_redis_key(user), 300)
|
||||
return StartCreateTotpKeyResp(
|
||||
secret=secret,
|
||||
uri=pyotp.totp.TOTP(secret).provisioning_uri(name=user.email, issuer_name=settings.totp_issuer),
|
||||
)
|
||||
|
||||
|
||||
def verify_totp_key(secret: str, code: str) -> bool:
|
||||
return pyotp.TOTP(secret).verify(code, valid_window=1)
|
||||
|
||||
|
||||
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)]
|
||||
|
||||
|
||||
async def _store_totp_key(user: User, secret: str, db: AsyncSession) -> list[str]:
|
||||
backup_codes = _generate_backup_codes()
|
||||
hashed_codes = [bcrypt.hashpw(code.encode(), bcrypt.gensalt()) for code in backup_codes]
|
||||
totp_secret = TotpKeys(user_id=user.id, secret=secret, backup_keys=[code.decode() for code in hashed_codes])
|
||||
db.add(totp_secret)
|
||||
await db.commit()
|
||||
return backup_codes
|
||||
|
||||
|
||||
async def finish_create_totp_key(
|
||||
user: User, code: str, redis: Redis, db: AsyncSession
|
||||
) -> tuple[FinishStatus, list[str]]:
|
||||
data = await redis.hgetall(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
if not data or "secret" not in data or "fails" not in data:
|
||||
return FinishStatus.INVALID, []
|
||||
|
||||
secret = data["secret"]
|
||||
fails = int(data["fails"])
|
||||
|
||||
if fails >= 3:
|
||||
await redis.delete(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
return FinishStatus.TOO_MANY_ATTEMPTS, []
|
||||
|
||||
if verify_totp_key(secret, code):
|
||||
await redis.delete(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
backup_codes = await _store_totp_key(user, secret, db)
|
||||
return FinishStatus.SUCCESS, backup_codes
|
||||
else:
|
||||
fails += 1
|
||||
await redis.hset(totp_redis_key(user), "fails", str(fails)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
return FinishStatus.FAILED, []
|
||||
|
||||
|
||||
async def disable_totp(user: User, db: AsyncSession) -> None:
|
||||
totp = await db.get(TotpKeys, user.id)
|
||||
if totp:
|
||||
await db.delete(totp)
|
||||
await db.commit()
|
||||
|
||||
|
||||
def check_totp_backup_code(totp: TotpKeys, code: str) -> bool:
|
||||
for hashed_code in totp.backup_keys:
|
||||
if bcrypt.checkpw(code.encode(), hashed_code.encode()):
|
||||
copy = totp.backup_keys[:]
|
||||
copy.remove(hashed_code)
|
||||
totp.backup_keys = copy
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -297,41 +297,47 @@ STORAGE_SETTINGS='{
|
||||
"日志设置",
|
||||
]
|
||||
|
||||
# 邮件服务设置
|
||||
# 验证服务设置
|
||||
enable_totp_verification: Annotated[bool, Field(default=True, description="是否启用TOTP双因素验证"), "验证服务设置"]
|
||||
totp_issuer: Annotated[
|
||||
str | None,
|
||||
Field(default=None, description="TOTP 认证器中的发行者名称"),
|
||||
"验证服务设置",
|
||||
]
|
||||
enable_email_verification: Annotated[
|
||||
bool,
|
||||
Field(default=False, description="是否启用邮件验证功能"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
smtp_server: Annotated[
|
||||
str,
|
||||
Field(default="localhost", description="SMTP 服务器地址"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
smtp_port: Annotated[
|
||||
int,
|
||||
Field(default=587, description="SMTP 服务器端口"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
smtp_username: Annotated[
|
||||
str,
|
||||
Field(default="", description="SMTP 用户名"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
smtp_password: Annotated[
|
||||
str,
|
||||
Field(default="", description="SMTP 密码"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
from_email: Annotated[
|
||||
str,
|
||||
Field(default="noreply@example.com", description="发件人邮箱"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
from_name: Annotated[
|
||||
str,
|
||||
Field(default="osu! server", description="发件人名称"),
|
||||
"邮件服务设置",
|
||||
"验证服务设置",
|
||||
]
|
||||
|
||||
# 监控配置
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
BANCHOBOT_ID = 2
|
||||
|
||||
BACKUP_CODE_LENGTH = 10
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .achievement import UserAchievement, UserAchievementResp
|
||||
from .auth import OAuthClient, OAuthToken, V1APIKeys
|
||||
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
|
||||
from .beatmap import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
@@ -25,10 +25,10 @@ from .counts import (
|
||||
ReplayWatchedCount,
|
||||
)
|
||||
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
|
||||
from .email_verification import EmailVerification, LoginSession
|
||||
from .events import Event
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .lazer_user import (
|
||||
MeResp,
|
||||
User,
|
||||
UserResp,
|
||||
)
|
||||
@@ -67,6 +67,7 @@ from .user_account_history import (
|
||||
UserAccountHistoryType,
|
||||
)
|
||||
from .user_login_log import UserLoginLog
|
||||
from .verification import EmailVerification, LoginSession
|
||||
|
||||
__all__ = [
|
||||
"APIUploadedRoom",
|
||||
@@ -93,6 +94,7 @@ __all__ = [
|
||||
"ItemAttemptsCount",
|
||||
"ItemAttemptsResp",
|
||||
"LoginSession",
|
||||
"MeResp",
|
||||
"MonthlyPlaycounts",
|
||||
"MultiplayerEvent",
|
||||
"MultiplayerEventResp",
|
||||
@@ -126,6 +128,7 @@ __all__ = [
|
||||
"Team",
|
||||
"TeamMember",
|
||||
"TeamRequest",
|
||||
"TotpKeys",
|
||||
"User",
|
||||
"UserAccountHistory",
|
||||
"UserAccountHistoryResp",
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import utcnow
|
||||
|
||||
from .verification import LoginSession
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -23,7 +25,7 @@ if TYPE_CHECKING:
|
||||
class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__: str = "oauth_tokens"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
client_id: int = Field(index=True)
|
||||
access_token: str = Field(max_length=500, unique=True)
|
||||
@@ -34,6 +36,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship()
|
||||
login_session: LoginSession | None = Relationship(back_populates="token", passive_deletes=True)
|
||||
|
||||
|
||||
class OAuthClient(SQLModel, table=True):
|
||||
@@ -52,3 +55,13 @@ class V1APIKeys(SQLModel, table=True):
|
||||
name: str = Field(max_length=100, index=True)
|
||||
key: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
|
||||
class TotpKeys(SQLModel, table=True):
|
||||
__tablename__: str = "totp_keys"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
|
||||
secret: str = Field(max_length=100)
|
||||
backup_keys: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship(back_populates="totp_key")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict
|
||||
|
||||
from app.config import settings
|
||||
from app.database.auth import TotpKeys
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import Country, Page
|
||||
@@ -166,6 +167,7 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
back_populates="user",
|
||||
)
|
||||
events: list[Event] = Relationship(back_populates="user")
|
||||
totp_key: TotpKeys | None = Relationship(back_populates="user")
|
||||
|
||||
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
|
||||
priv: int = Field(default=1, exclude=True)
|
||||
@@ -255,6 +257,8 @@ class UserResp(UserBase):
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
ruleset: GameMode | None = None,
|
||||
*,
|
||||
token_id: int | None = None,
|
||||
) -> "UserResp":
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
@@ -421,26 +425,42 @@ class UserResp(UserBase):
|
||||
)
|
||||
).one()
|
||||
|
||||
# 检查会话验证状态
|
||||
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
|
||||
if "session_verified" in include:
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
if not settings.enable_email_verification:
|
||||
u.session_verified = True
|
||||
u.session_verified = (
|
||||
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
|
||||
if token_id
|
||||
else True
|
||||
)
|
||||
|
||||
return u
|
||||
|
||||
|
||||
class MeResp(UserResp):
|
||||
session_verification_method: Literal["totp", "mail"] | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
obj: User,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
ruleset: GameMode | None = None,
|
||||
*,
|
||||
token_id: int | None = None,
|
||||
) -> "MeResp":
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id)
|
||||
u = cls.model_validate(u.model_dump())
|
||||
if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
|
||||
redis = get_redis()
|
||||
if not u.session_verified:
|
||||
u.session_verification_method = await LoginSessionService.get_login_method(obj.id, token_id, redis)
|
||||
else:
|
||||
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
|
||||
from .email_verification import LoginSession
|
||||
|
||||
unverified_session = (
|
||||
await session.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == obj.id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
u.session_verified = unverified_session is None
|
||||
|
||||
u.session_verification_method = None
|
||||
return u
|
||||
|
||||
|
||||
@@ -455,6 +475,7 @@ ALL_INCLUDED = [
|
||||
"monthly_playcounts",
|
||||
"replays_watched_counts",
|
||||
"rank_history",
|
||||
"session_verified",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
邮件验证相关数据库模型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from app.utils import utcnow
|
||||
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel import Field, Integer, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .auth import OAuthToken
|
||||
|
||||
|
||||
class EmailVerification(SQLModel, table=True):
|
||||
@@ -36,7 +38,9 @@ class LoginSession(SQLModel, table=True):
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
session_token: str = Field(unique=True, index=True) # 会话令牌
|
||||
token_id: int | None = Field(
|
||||
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True)
|
||||
)
|
||||
ip_address: str = Field() # 登录IP
|
||||
user_agent: str | None = Field(default=None, max_length=250)
|
||||
country_code: str | None = Field(default=None)
|
||||
@@ -45,3 +49,5 @@ class LoginSession(SQLModel, table=True):
|
||||
verified_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime = Field() # 会话过期时间
|
||||
is_new_location: bool = Field(default=False) # 是否新位置登录
|
||||
|
||||
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")
|
||||
@@ -5,7 +5,7 @@ from typing import Annotated
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
from app.database.auth import V1APIKeys
|
||||
from app.database.auth import OAuthToken, V1APIKeys
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
|
||||
from .database import Database
|
||||
@@ -75,10 +75,10 @@ async def v1_authorize(
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
|
||||
async def get_client_user(
|
||||
async def get_client_user_and_token(
|
||||
db: Database,
|
||||
token: Annotated[str, Depends(oauth2_password)],
|
||||
):
|
||||
) -> tuple[User, OAuthToken]:
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
@@ -87,17 +87,33 @@ async def get_client_user(
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
await db.refresh(user)
|
||||
return user, token_record
|
||||
|
||||
|
||||
UserAndToken = tuple[User, OAuthToken]
|
||||
|
||||
|
||||
async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get_client_user_and_token)):
|
||||
return user_and_token[0]
|
||||
|
||||
|
||||
async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)):
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
user, token = user_and_token
|
||||
|
||||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||||
raise HTTPException(status_code=403, detail="User not verified")
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
async def get_current_user_and_token(
|
||||
db: Database,
|
||||
security_scopes: SecurityScopes,
|
||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
|
||||
) -> User:
|
||||
) -> UserAndToken:
|
||||
"""获取当前认证用户"""
|
||||
token = token_pw or token_code or token_client_credentials
|
||||
if not token:
|
||||
@@ -120,6 +136,10 @@ async def get_current_user(
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
return user, token_record
|
||||
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_current_user(
|
||||
user_and_token: UserAndToken = Depends(get_current_user_and_token),
|
||||
) -> User:
|
||||
return user_and_token[0]
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
"""
|
||||
APIMe 响应模型 - 对应 osu! 的 APIMe 类型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.lazer_user import UserResp
|
||||
|
||||
|
||||
class APIMe(UserResp):
|
||||
"""
|
||||
/me 端点的响应模型
|
||||
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
|
||||
|
||||
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
|
||||
"""
|
||||
|
||||
pass
|
||||
16
app/models/totp.py
Normal file
16
app/models/totp.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class StartCreateTotpKeyResp(TypedDict):
|
||||
secret: str
|
||||
uri: str
|
||||
|
||||
|
||||
class FinishStatus(str, Enum):
|
||||
INVALID = "invalid"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
TOO_MANY_ATTEMPTS = "too_many_attempts"
|
||||
@@ -17,6 +17,7 @@ from app.auth import (
|
||||
from app.config import settings
|
||||
from app.const import BANCHOBOT_ID
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
@@ -30,12 +31,12 @@ from app.models.oauth import (
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
from app.service.email_verification_service import (
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
from app.service.verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
from app.utils import utcnow
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
@@ -287,8 +288,23 @@ async def oauth_token(
|
||||
# 确保用户对象与当前会话关联
|
||||
await db.refresh(user)
|
||||
|
||||
# 获取用户信息和客户端信息
|
||||
user_id = user.id
|
||||
totp_key: TotpKeys | None = await user.awaitable_attrs.totp_key
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
token = await store_token(
|
||||
db,
|
||||
user_id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
token_id = token.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
@@ -300,15 +316,22 @@ async def oauth_token(
|
||||
# 检查是否为新位置登录
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
# 创建登录会话记录
|
||||
login_session = await LoginSessionService.create_session( # noqa: F841
|
||||
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
|
||||
)
|
||||
|
||||
# 如果是新位置登录,需要邮件验证
|
||||
if is_new_location and settings.enable_email_verification:
|
||||
session_verification_method = None
|
||||
if settings.enable_totp_verification and totp_key is not None:
|
||||
session_verification_method = "totp"
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
login_success=True,
|
||||
login_method="password_pending_verification",
|
||||
notes="需要 TOTP 验证",
|
||||
)
|
||||
elif is_new_location and settings.enable_email_verification:
|
||||
# 如果是新位置登录,需要邮件验证
|
||||
# 刷新用户对象以确保属性已加载
|
||||
await db.refresh(user)
|
||||
session_verification_method = "mail"
|
||||
|
||||
# 发送邮件验证码
|
||||
verification_sent = await EmailVerificationService.send_verification_email(
|
||||
@@ -328,9 +351,9 @@ async def oauth_token(
|
||||
if not verification_sent:
|
||||
# 邮件发送失败,记录错误
|
||||
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
|
||||
elif is_new_location and not settings.enable_email_verification:
|
||||
elif is_new_location:
|
||||
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
|
||||
logger.debug(
|
||||
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
|
||||
)
|
||||
@@ -345,25 +368,16 @@ async def oauth_token(
|
||||
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
# 无论是否新位置登录,都返回正常的token
|
||||
# session_verified状态通过/me接口的session_verified字段来体现
|
||||
if session_verification_method:
|
||||
await LoginSessionService.create_session(
|
||||
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, False
|
||||
)
|
||||
await LoginSessionService.set_login_method(user_id, token_id, session_verification_method, redis)
|
||||
else:
|
||||
await LoginSessionService.create_session(
|
||||
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, True
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 获取用户ID,避免触发延迟加载
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.dependencies.database import (
|
||||
get_redis,
|
||||
with_db,
|
||||
)
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.dependencies.user import get_current_user_and_token
|
||||
from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
@@ -311,7 +311,11 @@ async def chat_websocket(
|
||||
await websocket.close(code=1008, reason="Missing authentication token")
|
||||
return
|
||||
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None:
|
||||
if (
|
||||
user_and_token := await get_current_user_and_token(
|
||||
session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token
|
||||
)
|
||||
) is None:
|
||||
await websocket.close(code=1008, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
@@ -320,6 +324,7 @@ async def chat_websocket(
|
||||
if login.get("event") != "chat.start":
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user = user_and_token[0]
|
||||
user_id = user.id
|
||||
server.connect(user_id, websocket)
|
||||
# 使用明确的查询避免延迟加载
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from . import avatar, beatmapset_ratings, cover, oauth, relationship, team, username # noqa: F401
|
||||
from .router import router as private_router
|
||||
|
||||
if settings.enable_totp_verification:
|
||||
from . import totp # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"private_router",
|
||||
]
|
||||
|
||||
104
app/router/private/totp.py
Normal file
104
app/router/private/totp.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.auth import (
|
||||
check_totp_backup_code,
|
||||
finish_create_totp_key,
|
||||
start_create_totp_key,
|
||||
totp_redis_key,
|
||||
verify_totp_key,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.database.auth import TotpKeys
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException, Security
|
||||
import pyotp
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
@router.post(
|
||||
"/totp/create",
|
||||
name="开始 TOTP 创建流程",
|
||||
description=(
|
||||
"开始 TOTP 创建流程\n\n"
|
||||
"返回 TOTP 密钥和 URI,供用户在身份验证器应用中添加账户。\n\n"
|
||||
"然后将身份验证器应用提供的 TOTP 代码请求 PUT `/api/private/totp/create` 来完成 TOTP 创建流程。\n\n"
|
||||
"若 5 分钟内未完成或错误 3 次以上则创建流程需要重新开始。"
|
||||
),
|
||||
tags=["验证", "g0v0 API"],
|
||||
response_model=StartCreateTotpKeyResp,
|
||||
status_code=201,
|
||||
)
|
||||
async def start_create_totp(
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if await current_user.awaitable_attrs.totp_key:
|
||||
raise HTTPException(status_code=400, detail="TOTP is already enabled for this user")
|
||||
|
||||
previous = await redis.hgetall(totp_redis_key(current_user)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
if previous: # pyright: ignore[reportGeneralTypeIssues]
|
||||
return StartCreateTotpKeyResp(
|
||||
secret=previous["secret"],
|
||||
uri=pyotp.totp.TOTP(previous["secret"]).provisioning_uri(
|
||||
name=current_user.email,
|
||||
issuer_name=settings.totp_issuer,
|
||||
),
|
||||
)
|
||||
return await start_create_totp_key(current_user, redis)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/totp/create",
|
||||
name="完成 TOTP 创建流程",
|
||||
description=(
|
||||
"完成 TOTP 创建流程,验证用户提供的 TOTP 代码。\n\n"
|
||||
"- 如果验证成功,启用用户的 TOTP 双因素验证,并返回备份码。\n- 如果验证失败,返回错误信息。"
|
||||
),
|
||||
tags=["验证", "g0v0 API"],
|
||||
response_model=list[str],
|
||||
status_code=201,
|
||||
)
|
||||
async def finish_create_totp(
|
||||
session: Database,
|
||||
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
status, backup_codes = await finish_create_totp_key(current_user, code, redis, session)
|
||||
if status == FinishStatus.SUCCESS:
|
||||
return backup_codes
|
||||
elif status == FinishStatus.INVALID:
|
||||
raise HTTPException(status_code=400, detail="No TOTP setup in progress or invalid data")
|
||||
elif status == FinishStatus.TOO_MANY_ATTEMPTS:
|
||||
raise HTTPException(status_code=400, detail="Too many failed attempts. Please start over.")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid TOTP code")
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/totp",
|
||||
name="禁用 TOTP 双因素验证",
|
||||
description="禁用当前用户的 TOTP 双因素验证",
|
||||
tags=["验证", "g0v0 API"],
|
||||
status_code=204,
|
||||
)
|
||||
async def disable_totp(
|
||||
session: Database,
|
||||
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
totp = await session.get(TotpKeys, current_user.id)
|
||||
if not totp:
|
||||
raise HTTPException(status_code=400, detail="TOTP is not enabled for this user")
|
||||
if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)):
|
||||
await session.delete(totp)
|
||||
await session.commit()
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid TOTP code or backup code")
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import User
|
||||
from app.database import MeResp, User
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import UserAndToken, get_current_user_and_token
|
||||
from app.exceptions.userpage import UserpageError
|
||||
from app.models.api_me import APIMe
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import Page
|
||||
from app.models.userpage import (
|
||||
@@ -23,7 +23,7 @@ from fastapi import HTTPException, Path, Security
|
||||
|
||||
@router.get(
|
||||
"/me/{ruleset}",
|
||||
response_model=APIMe,
|
||||
response_model=MeResp,
|
||||
name="获取当前用户信息 (指定 ruleset)",
|
||||
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
|
||||
tags=["用户"],
|
||||
@@ -31,34 +31,24 @@ from fastapi import HTTPException, Path, Security
|
||||
async def get_user_info_with_ruleset(
|
||||
session: Database,
|
||||
ruleset: GameMode = Path(description="指定 ruleset"),
|
||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
):
|
||||
user_resp = await APIMe.from_db(
|
||||
current_user,
|
||||
session,
|
||||
ALL_INCLUDED,
|
||||
ruleset,
|
||||
)
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, ruleset, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me/",
|
||||
response_model=APIMe,
|
||||
response_model=MeResp,
|
||||
name="获取当前用户信息",
|
||||
description="获取当前登录用户信息。",
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info_default(
|
||||
session: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["identify"]),
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
):
|
||||
user_resp = await APIMe.from_db(
|
||||
current_user,
|
||||
session,
|
||||
ALL_INCLUDED,
|
||||
None,
|
||||
)
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, None, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
|
||||
|
||||
|
||||
@@ -4,24 +4,35 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from app.database import User
|
||||
from app.dependencies import get_current_user
|
||||
from app.auth import check_totp_backup_code, verify_totp_key
|
||||
from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.database.auth import TotpKeys
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
)
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
from app.dependencies.user import UserAndToken, get_client_user_and_token
|
||||
from app.log import logger
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService,
|
||||
)
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Request, Security, status
|
||||
from fastapi.responses import Response
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class VerifyMethod(BaseModel):
|
||||
method: Literal["totp", "mail"] = "mail"
|
||||
|
||||
|
||||
class SessionReissueResponse(BaseModel):
|
||||
"""重新发送验证码响应"""
|
||||
|
||||
@@ -29,66 +40,94 @@ class SessionReissueResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class VerifyFailed(Exception): ...
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/verify", name="验证会话", description="验证邮件验证码并完成会话认证", status_code=204, tags=["验证"]
|
||||
"/session/verify",
|
||||
name="验证会话",
|
||||
description="验证邮件验证码并完成会话认证",
|
||||
status_code=204,
|
||||
tags=["验证"],
|
||||
responses={
|
||||
401: {"model": VerifyMethod, "description": "验证失败,返回当前使用的验证方法"},
|
||||
204: {"description": "验证成功,无内容返回"},
|
||||
},
|
||||
)
|
||||
async def verify_session(
|
||||
request: Request,
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8位邮件验证码"),
|
||||
current_user: User = Security(get_current_user),
|
||||
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"),
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
) -> Response:
|
||||
"""
|
||||
验证邮件验证码并完成会话认证
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
user_id = current_user.id
|
||||
|
||||
if not await LoginSessionService.check_is_need_verification(db, user_id, token_id):
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
verify_method: str | None = (
|
||||
"mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis)
|
||||
)
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
login_method = "password"
|
||||
|
||||
对应 osu! 的 session/verify 接口
|
||||
成功时返回 204 No Content,失败时返回 401 Unauthorized
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key
|
||||
if verify_method is None:
|
||||
verify_method = "totp" if totp_key else "mail"
|
||||
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
|
||||
login_method = verify_method
|
||||
|
||||
ip_address = get_client_ip(request) # noqa: F841
|
||||
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
|
||||
if verify_method == "totp":
|
||||
if not totp_key:
|
||||
if settings.enable_email_verification:
|
||||
await LoginSessionService.set_login_method(user_id, token_id, "mail", redis)
|
||||
await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
)
|
||||
verify_method = "mail"
|
||||
raise VerifyFailed("用户未设置 TOTP,已发送邮件验证码")
|
||||
# 如果未开启邮箱验证,则直接认为认证通过
|
||||
# 正常不会进入到这里
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
|
||||
|
||||
if success:
|
||||
# 记录成功的邮件验证
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
login_method="email_verification",
|
||||
login_success=True,
|
||||
notes="邮件验证成功",
|
||||
)
|
||||
|
||||
# 返回 204 No Content 表示验证成功
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
elif verify_totp_key(totp_key.secret, verification_key):
|
||||
pass
|
||||
elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key):
|
||||
login_method = "totp_backup_code"
|
||||
else:
|
||||
raise VerifyFailed("TOTP 验证失败")
|
||||
else:
|
||||
# 记录失败的邮件验证尝试
|
||||
await LoginLogService.record_failed_login(
|
||||
db=db,
|
||||
request=request,
|
||||
attempted_username=current_user.username,
|
||||
login_method="email_verification",
|
||||
notes=f"邮件验证失败: {message}",
|
||||
)
|
||||
success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
|
||||
if not success:
|
||||
raise VerifyFailed(f"邮件验证失败: {message}")
|
||||
|
||||
# 返回 401 Unauthorized 表示验证失败
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
login_method=login_method,
|
||||
login_success=True,
|
||||
notes=f"{login_method} 验证成功",
|
||||
)
|
||||
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
|
||||
await db.commit()
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
|
||||
except VerifyFailed as e:
|
||||
await LoginLogService.record_failed_login(
|
||||
db=db,
|
||||
request=request,
|
||||
attempted_username=current_user.username,
|
||||
login_method=login_method,
|
||||
notes=str(e),
|
||||
)
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method})
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -101,26 +140,27 @@ async def verify_session(
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
current_user: User = Security(get_current_user),
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
) -> SessionReissueResponse:
|
||||
"""
|
||||
重新发送邮件验证码
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
user_id = current_user.id
|
||||
|
||||
if not await LoginSessionService.check_is_need_verification(db, user_id, token_id):
|
||||
return SessionReissueResponse(success=False, message="当前会话不需要验证")
|
||||
|
||||
verify_method: str | None = (
|
||||
"mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis)
|
||||
)
|
||||
if verify_method != "mail":
|
||||
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
|
||||
|
||||
对应 osu! 的 session/verify/reissue 接口
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
return SessionReissueResponse(success=False, message="用户未认证")
|
||||
|
||||
# 重新发送验证码
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
redis,
|
||||
@@ -137,3 +177,41 @@ async def reissue_verification_code(
|
||||
return SessionReissueResponse(success=False, message="无效的用户会话")
|
||||
except Exception:
|
||||
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/verify/mail-fallback",
|
||||
name="邮件验证码回退",
|
||||
description="当 TOTP 验证不可用时,使用邮件验证码进行回退验证",
|
||||
response_model=VerifyMethod,
|
||||
tags=["验证"],
|
||||
)
|
||||
async def fallback_email(
|
||||
db: Database,
|
||||
request: Request,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
) -> VerifyMethod:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
redis,
|
||||
current_user.id,
|
||||
current_user.username,
|
||||
current_user.email,
|
||||
ip_address,
|
||||
user_agent,
|
||||
)
|
||||
if not success:
|
||||
logger.error(
|
||||
f"[Email Fallback] Failed to send fallback email to user {current_user.id} (token: {token_id}): {message}"
|
||||
)
|
||||
return VerifyMethod()
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.utils import utcnow
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
from app.database.user_login_log import UserLoginLog
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
|
||||
from app.log import logger
|
||||
from app.utils import utcnow
|
||||
from app.utils import simplify_user_agent, utcnow
|
||||
|
||||
from fastapi import Request
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -45,9 +45,6 @@ class LoginLogService:
|
||||
raw_ip = get_client_ip(request)
|
||||
ip_address = normalize_ip(raw_ip)
|
||||
|
||||
# 获取并简化User-Agent
|
||||
from app.utils import simplify_user_agent
|
||||
|
||||
raw_user_agent = request.headers.get("User-Agent", "")
|
||||
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
|
||||
|
||||
|
||||
@@ -7,15 +7,16 @@ from __future__ import annotations
|
||||
from datetime import timedelta
|
||||
import secrets
|
||||
import string
|
||||
from typing import Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -279,20 +280,18 @@ This email was sent automatically, please do not reply.
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def verify_code(
|
||||
async def verify_email_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
code: str,
|
||||
ip_address: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
"""验证邮箱验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
|
||||
# 仍然标记登录会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
return True, "验证成功(邮件验证功能已禁用)"
|
||||
|
||||
# 先从 Redis 检查
|
||||
@@ -319,9 +318,6 @@ This email was sent automatically, please do not reply.
|
||||
verification.is_used = True
|
||||
verification.used_at = utcnow()
|
||||
|
||||
# 同时更新对应的登录会话状态
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 删除 Redis 记录
|
||||
@@ -382,10 +378,12 @@ class LoginSessionService:
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
token_id: int,
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
is_verified: bool = False,
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
|
||||
@@ -393,13 +391,13 @@ class LoginSessionService:
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
token_id=token_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=None,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
|
||||
is_verified=not is_new_location, # 新位置需要验证
|
||||
is_verified=is_verified,
|
||||
)
|
||||
|
||||
db.add(session)
|
||||
@@ -416,46 +414,21 @@ class LoginSessionService:
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
async def verify_session(
|
||||
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
|
||||
) -> tuple[bool, str]:
|
||||
"""验证会话(通过邮件验证码)"""
|
||||
try:
|
||||
# 从 Redis 获取用户ID
|
||||
user_id = await redis.get(f"login_session:{session_token}")
|
||||
if not user_id:
|
||||
return False, "会话无效或已过期"
|
||||
@classmethod
|
||||
def _session_verify_redis_key(cls, user_id: int, token_id: int) -> str:
|
||||
return f"session_verification_method:{user_id}:{token_id}"
|
||||
|
||||
user_id = int(user_id)
|
||||
@classmethod
|
||||
async def get_login_method(cls, user_id: int, token_id: int, redis: Redis) -> Literal["totp", "mail"] | None:
|
||||
return await redis.get(cls._session_verify_redis_key(user_id, token_id))
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
|
||||
@classmethod
|
||||
async def set_login_method(cls, user_id: int, token_id: int, method: Literal["totp", "mail"], redis: Redis) -> None:
|
||||
await redis.set(cls._session_verify_redis_key(user_id, token_id), method)
|
||||
|
||||
if not success:
|
||||
return False, message
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_token,
|
||||
LoginSession.user_id == user_id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
session = result.first()
|
||||
if session:
|
||||
session.is_verified = True
|
||||
session.verified_at = utcnow()
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||
return True, "会话验证成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
@classmethod
|
||||
async def clear_login_method(cls, user_id: int, token_id: int, redis: Redis) -> None:
|
||||
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
@@ -485,7 +458,7 @@ class LoginSessionService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
|
||||
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
@@ -494,6 +467,7 @@ class LoginSessionService:
|
||||
LoginSession.user_id == user_id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -507,8 +481,27 @@ class LoginSessionService:
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
await LoginSessionService.clear_login_method(user_id, token_id, redis)
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def check_is_need_verification(db: AsyncSession, user_id: int, token_id: int) -> bool:
|
||||
"""检查用户是否需要验证(有未验证的会话)"""
|
||||
if settings.enable_totp_verification or settings.enable_email_verification:
|
||||
unverified_session = (
|
||||
await db.exec(
|
||||
select(exists()).where(
|
||||
LoginSession.user_id == user_id,
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
return unverified_session or False
|
||||
return False
|
||||
@@ -9,6 +9,7 @@ import uuid
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import DBFactory, get_db_factory
|
||||
from app.dependencies.user import get_current_user_and_token
|
||||
from app.log import logger
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
|
||||
@@ -61,9 +62,11 @@ async def connect(
|
||||
return
|
||||
try:
|
||||
async for session in factory():
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["*"]), token_pw=token)) is None or str(
|
||||
user.id
|
||||
) != user_id:
|
||||
if (
|
||||
user_and_token := await get_current_user_and_token(
|
||||
session, SecurityScopes(scopes=["*"]), token_pw=token
|
||||
)
|
||||
) is None or str(user_and_token[0].id) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
except HTTPException:
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""auth: add totp keys
|
||||
|
||||
Revision ID: 15e3a9a05b67
|
||||
Revises: ebaa317ad928
|
||||
Create Date: 2025-09-20 11:27:58.485299
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "15e3a9a05b67"
|
||||
down_revision: str | Sequence[str] | None = "ebaa317ad928"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"totp_keys",
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("secret", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
|
||||
sa.Column("backup_keys", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_id"),
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("totp_keys")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,53 @@
|
||||
"""login_sessions: remove session_token & add token_id
|
||||
|
||||
Revision ID: fe8e9f3da298
|
||||
Revises: 15e3a9a05b67
|
||||
Create Date: 2025-09-21 02:30:58.233846
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "fe8e9f3da298"
|
||||
down_revision: str | Sequence[str] | None = "15e3a9a05b67"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("login_sessions", sa.Column("token_id", sa.Integer(), nullable=True))
|
||||
op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions")
|
||||
op.create_index(op.f("ix_login_sessions_token_id"), "login_sessions", ["token_id"], unique=False)
|
||||
op.create_foreign_key(None, "login_sessions", "oauth_tokens", ["token_id"], ["id"], ondelete="SET NULL")
|
||||
op.drop_column("login_sessions", "session_token")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("login_sessions", sa.Column("session_token", mysql.VARCHAR(length=255), nullable=True))
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
UPDATE login_sessions
|
||||
SET session_token = CONCAT('migrated_', id, '_', UNIX_TIMESTAMP(), '_', RAND())
|
||||
WHERE session_token IS NULL
|
||||
""")
|
||||
)
|
||||
op.alter_column("login_sessions", "session_token", nullable=False, type_=mysql.VARCHAR(length=255))
|
||||
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=True)
|
||||
|
||||
op.drop_constraint(op.f("login_sessions_ibfk_1"), "login_sessions", type_="foreignkey")
|
||||
op.drop_index(op.f("ix_login_sessions_token_id"), table_name="login_sessions")
|
||||
op.drop_column("login_sessions", "token_id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -25,6 +25,7 @@ dependencies = [
|
||||
"pillow>=11.3.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"pydantic[email]>=2.5.0",
|
||||
"pyotp>=2.9.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.6",
|
||||
|
||||
11
uv.lock
generated
11
uv.lock
generated
@@ -602,6 +602,7 @@ dependencies = [
|
||||
{ name = "pillow" },
|
||||
{ name = "pydantic", extra = ["email"] },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyotp" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
@@ -645,6 +646,7 @@ requires-dist = [
|
||||
{ name = "pillow", specifier = ">=11.3.0" },
|
||||
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
||||
{ name = "pyotp", specifier = ">=2.9.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||
@@ -1285,6 +1287,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/4c/ad33b92b9864cbde84f259d5df035a6447f91891f5be77788e2a3892bce3/pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9", size = 45300, upload-time = "2025-08-24T12:55:53.394Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyotp"
|
||||
version = "2.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.405"
|
||||
|
||||
Reference in New Issue
Block a user