From 1527e23b4337d14e99b8ee80bfe98f9d38e5a14e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 21 Sep 2025 19:50:11 +0800 Subject: [PATCH] =?UTF-8?q?feat(session-verify):=20=E6=B7=BB=E5=8A=A0=20TO?= =?UTF-8?q?TP=20=E6=94=AF=E6=8C=81=20(#34)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore(deps): add pyotp * feat(auth): implement TOTP verification feat(auth): implement TOTP verification and email verification services - Added TOTP keys management with a new database model `TotpKeys`. - Introduced `EmailVerification` and `LoginSession` models for email verification. - Created `verification_service` to handle email verification logic and TOTP processes. - Updated user response models to include session verification methods. - Implemented routes for TOTP creation, verification, and fallback to email verification. - Enhanced login session management to support new location checks and verification methods. - Added migration script to create `totp_keys` table in the database. * feat(config): update config example * docs(totp): complete creating TOTP flow * refactor(totp): resolve review * feat(api): forbid unverified request * fix(totp): trace session by token id to avoid other sessions are forbidden * chore(linter): make pyright happy * fix(totp): only mark sessions with a specified token id --- .env.example | 4 +- app/auth.py | 77 +++++++ app/config.py | 22 +- app/const.py | 2 + app/database/__init__.py | 7 +- app/database/auth.py | 15 +- app/database/lazer_user.py | 59 +++-- ...{email_verification.py => verification.py} | 14 +- app/dependencies/user.py | 36 ++- app/models/api_me.py | 18 -- app/models/totp.py | 16 ++ app/router/auth.py | 76 ++++--- app/router/notification/server.py | 9 +- app/router/private/__init__.py | 5 + app/router/private/totp.py | 104 +++++++++ app/router/v2/me.py | 26 +-- app/router/v2/session_verify.py | 210 ++++++++++++------ app/service/database_cleanup_service.py | 2 +- app/service/login_log_service.py | 5 +- ...ion_service.py => verification_service.py} | 91 ++++---- app/signalr/router.py | 9 +- ...5-09-20_15e3a9a05b67_auth_add_totp_keys.py | 47 ++++ ...ogin_sessions_remove_session_token_add_.py | 53 +++++ pyproject.toml | 1 + uv.lock | 11 + 25 files changed, 684 insertions(+), 235 deletions(-) rename app/database/{email_verification.py => verification.py} (81%) delete mode 100644 app/models/api_me.py create mode 100644 app/models/totp.py create mode 100644 app/router/private/totp.py rename app/service/{email_verification_service.py => verification_service.py} (88%) create mode 100644 migrations/versions/2025-09-20_15e3a9a05b67_auth_add_totp_keys.py create mode 100644 migrations/versions/2025-09-21_fe8e9f3da298_login_sessions_remove_session_token_add_.py diff --git a/.env.example b/.env.example index 4bd4982..72c3a0e 100644 --- a/.env.example +++ b/.env.example @@ -42,7 +42,9 @@ FETCHER_SCOPES="public" # Logging Settings LOG_LEVEL="INFO" -# Email Service Settings +# Verification Settings +ENABLE_TOTP_VERIFICATION=true +TOTP_ISSUER="osu! server" ENABLE_EMAIL_VERIFICATION=false SMTP_SERVER="localhost" SMTP_PORT=587 diff --git a/app/auth.py b/app/auth.py index 356a73c..6f773a7 100644 --- a/app/auth.py +++ b/app/auth.py @@ -7,16 +7,20 @@ import secrets import string from app.config import settings +from app.const import BACKUP_CODE_LENGTH from app.database import ( OAuthToken, User, ) +from app.database.auth import TotpKeys from app.log import logger +from app.models.totp import FinishStatus, StartCreateTotpKeyResp from app.utils import utcnow import bcrypt from jose import JWTError, jwt from passlib.context import CryptContext +import pyotp from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -277,3 +281,76 @@ async def get_user_by_authorization_code( await db.refresh(user) return (user, scopes.split(",")) return None + + +def totp_redis_key(user: User) -> str: + return f"totp:setup:{user.email}" + + +async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp: + secret = pyotp.random_base32() + await redis.hset(totp_redis_key(user), mapping={"secret": secret, "fails": 0}) # pyright: ignore[reportGeneralTypeIssues] + await redis.expire(totp_redis_key(user), 300) + return StartCreateTotpKeyResp( + secret=secret, + uri=pyotp.totp.TOTP(secret).provisioning_uri(name=user.email, issuer_name=settings.totp_issuer), + ) + + +def verify_totp_key(secret: str, code: str) -> bool: + return pyotp.TOTP(secret).verify(code, valid_window=1) + + +def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]: + alphabet = string.ascii_uppercase + string.digits + return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)] + + +async def _store_totp_key(user: User, secret: str, db: AsyncSession) -> list[str]: + backup_codes = _generate_backup_codes() + hashed_codes = [bcrypt.hashpw(code.encode(), bcrypt.gensalt()) for code in backup_codes] + totp_secret = TotpKeys(user_id=user.id, secret=secret, backup_keys=[code.decode() for code in hashed_codes]) + db.add(totp_secret) + await db.commit() + return backup_codes + + +async def finish_create_totp_key( + user: User, code: str, redis: Redis, db: AsyncSession +) -> tuple[FinishStatus, list[str]]: + data = await redis.hgetall(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues] + if not data or "secret" not in data or "fails" not in data: + return FinishStatus.INVALID, [] + + secret = data["secret"] + fails = int(data["fails"]) + + if fails >= 3: + await redis.delete(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues] + return FinishStatus.TOO_MANY_ATTEMPTS, [] + + if verify_totp_key(secret, code): + await redis.delete(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues] + backup_codes = await _store_totp_key(user, secret, db) + return FinishStatus.SUCCESS, backup_codes + else: + fails += 1 + await redis.hset(totp_redis_key(user), "fails", str(fails)) # pyright: ignore[reportGeneralTypeIssues] + return FinishStatus.FAILED, [] + + +async def disable_totp(user: User, db: AsyncSession) -> None: + totp = await db.get(TotpKeys, user.id) + if totp: + await db.delete(totp) + await db.commit() + + +def check_totp_backup_code(totp: TotpKeys, code: str) -> bool: + for hashed_code in totp.backup_keys: + if bcrypt.checkpw(code.encode(), hashed_code.encode()): + copy = totp.backup_keys[:] + copy.remove(hashed_code) + totp.backup_keys = copy + return True + return False diff --git a/app/config.py b/app/config.py index 0146121..335a307 100644 --- a/app/config.py +++ b/app/config.py @@ -297,41 +297,47 @@ STORAGE_SETTINGS='{ "日志设置", ] - # 邮件服务设置 + # 验证服务设置 + enable_totp_verification: Annotated[bool, Field(default=True, description="是否启用TOTP双因素验证"), "验证服务设置"] + totp_issuer: Annotated[ + str | None, + Field(default=None, description="TOTP 认证器中的发行者名称"), + "验证服务设置", + ] enable_email_verification: Annotated[ bool, Field(default=False, description="是否启用邮件验证功能"), - "邮件服务设置", + "验证服务设置", ] smtp_server: Annotated[ str, Field(default="localhost", description="SMTP 服务器地址"), - "邮件服务设置", + "验证服务设置", ] smtp_port: Annotated[ int, Field(default=587, description="SMTP 服务器端口"), - "邮件服务设置", + "验证服务设置", ] smtp_username: Annotated[ str, Field(default="", description="SMTP 用户名"), - "邮件服务设置", + "验证服务设置", ] smtp_password: Annotated[ str, Field(default="", description="SMTP 密码"), - "邮件服务设置", + "验证服务设置", ] from_email: Annotated[ str, Field(default="noreply@example.com", description="发件人邮箱"), - "邮件服务设置", + "验证服务设置", ] from_name: Annotated[ str, Field(default="osu! server", description="发件人名称"), - "邮件服务设置", + "验证服务设置", ] # 监控配置 diff --git a/app/const.py b/app/const.py index 78ad45c..143e81c 100644 --- a/app/const.py +++ b/app/const.py @@ -1,3 +1,5 @@ from __future__ import annotations BANCHOBOT_ID = 2 + +BACKUP_CODE_LENGTH = 10 diff --git a/app/database/__init__.py b/app/database/__init__.py index b3a0e3d..60267d8 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -1,5 +1,5 @@ from .achievement import UserAchievement, UserAchievementResp -from .auth import OAuthClient, OAuthToken, V1APIKeys +from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys from .beatmap import ( Beatmap, BeatmapResp, @@ -25,10 +25,10 @@ from .counts import ( ReplayWatchedCount, ) from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp -from .email_verification import EmailVerification, LoginSession from .events import Event from .favourite_beatmapset import FavouriteBeatmapset from .lazer_user import ( + MeResp, User, UserResp, ) @@ -67,6 +67,7 @@ from .user_account_history import ( UserAccountHistoryType, ) from .user_login_log import UserLoginLog +from .verification import EmailVerification, LoginSession __all__ = [ "APIUploadedRoom", @@ -93,6 +94,7 @@ __all__ = [ "ItemAttemptsCount", "ItemAttemptsResp", "LoginSession", + "MeResp", "MonthlyPlaycounts", "MultiplayerEvent", "MultiplayerEventResp", @@ -126,6 +128,7 @@ __all__ = [ "Team", "TeamMember", "TeamRequest", + "TotpKeys", "User", "UserAccountHistory", "UserAccountHistoryResp", diff --git a/app/database/auth.py b/app/database/auth.py index fecb3ce..3024d20 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING from app.models.model import UTCBaseModel from app.utils import utcnow +from .verification import LoginSession + from sqlalchemy import Column, DateTime from sqlmodel import ( JSON, @@ -23,7 +25,7 @@ if TYPE_CHECKING: class OAuthToken(UTCBaseModel, SQLModel, table=True): __tablename__: str = "oauth_tokens" - id: int | None = Field(default=None, primary_key=True, index=True) + id: int = Field(default=None, primary_key=True, index=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) client_id: int = Field(index=True) access_token: str = Field(max_length=500, unique=True) @@ -34,6 +36,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True): created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime)) user: "User" = Relationship() + login_session: LoginSession | None = Relationship(back_populates="token", passive_deletes=True) class OAuthClient(SQLModel, table=True): @@ -52,3 +55,13 @@ class V1APIKeys(SQLModel, table=True): name: str = Field(max_length=100, index=True) key: str = Field(default_factory=secrets.token_hex, index=True) owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + + +class TotpKeys(SQLModel, table=True): + __tablename__: str = "totp_keys" + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)) + secret: str = Field(max_length=100) + backup_keys: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime)) + + user: "User" = Relationship(back_populates="totp_key") diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 2ead3fd..d14b3e1 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -1,8 +1,9 @@ from datetime import datetime, timedelta import json -from typing import TYPE_CHECKING, NotRequired, TypedDict +from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict from app.config import settings +from app.database.auth import TotpKeys from app.models.model import UTCBaseModel from app.models.score import GameMode from app.models.user import Country, Page @@ -166,6 +167,7 @@ class User(AsyncAttrs, UserBase, table=True): back_populates="user", ) events: list[Event] = Relationship(back_populates="user") + totp_key: TotpKeys | None = Relationship(back_populates="user") email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) @@ -255,6 +257,8 @@ class UserResp(UserBase): session: AsyncSession, include: list[str] = [], ruleset: GameMode | None = None, + *, + token_id: int | None = None, ) -> "UserResp": from app.dependencies.database import get_redis @@ -421,26 +425,42 @@ class UserResp(UserBase): ) ).one() - # 检查会话验证状态 - # 如果邮件验证功能被禁用,则始终设置 session_verified 为 true + if "session_verified" in include: + from app.service.verification_service import LoginSessionService - if not settings.enable_email_verification: - u.session_verified = True + u.session_verified = ( + not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) + if token_id + else True + ) + + return u + + +class MeResp(UserResp): + session_verification_method: Literal["totp", "mail"] | None = None + + @classmethod + async def from_db( + cls, + obj: User, + session: AsyncSession, + include: list[str] = [], + ruleset: GameMode | None = None, + *, + token_id: int | None = None, + ) -> "MeResp": + from app.dependencies.database import get_redis + from app.service.verification_service import LoginSessionService + + u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id) + u = cls.model_validate(u.model_dump()) + if (settings.enable_totp_verification or settings.enable_email_verification) and token_id: + redis = get_redis() + if not u.session_verified: + u.session_verification_method = await LoginSessionService.get_login_method(obj.id, token_id, redis) else: - # 如果用户有未验证的登录会话,则设置 session_verified 为 false - from .email_verification import LoginSession - - unverified_session = ( - await session.exec( - select(LoginSession).where( - LoginSession.user_id == obj.id, - col(LoginSession.is_verified).is_(False), - LoginSession.expires_at > utcnow(), - ) - ) - ).first() - u.session_verified = unverified_session is None - + u.session_verification_method = None return u @@ -455,6 +475,7 @@ ALL_INCLUDED = [ "monthly_playcounts", "replays_watched_counts", "rank_history", + "session_verified", ] diff --git a/app/database/email_verification.py b/app/database/verification.py similarity index 81% rename from app/database/email_verification.py rename to app/database/verification.py index 2fd65a9..52a42d7 100644 --- a/app/database/email_verification.py +++ b/app/database/verification.py @@ -2,14 +2,16 @@ 邮件验证相关数据库模型 """ -from __future__ import annotations - from datetime import datetime +from typing import TYPE_CHECKING, Optional from app.utils import utcnow from sqlalchemy import BigInteger, Column, ForeignKey -from sqlmodel import Field, SQLModel +from sqlmodel import Field, Integer, Relationship, SQLModel + +if TYPE_CHECKING: + from .auth import OAuthToken class EmailVerification(SQLModel, table=True): @@ -36,7 +38,9 @@ class LoginSession(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) - session_token: str = Field(unique=True, index=True) # 会话令牌 + token_id: int | None = Field( + sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True) + ) ip_address: str = Field() # 登录IP user_agent: str | None = Field(default=None, max_length=250) country_code: str | None = Field(default=None) @@ -45,3 +49,5 @@ class LoginSession(SQLModel, table=True): verified_at: datetime | None = Field(default=None) expires_at: datetime = Field() # 会话过期时间 is_new_location: bool = Field(default=False) # 是否新位置登录 + + token: Optional["OAuthToken"] = Relationship(back_populates="login_session") diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 69f3edd..674df37 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -5,7 +5,7 @@ from typing import Annotated from app.auth import get_token_by_access_token from app.config import settings from app.database import User -from app.database.auth import V1APIKeys +from app.database.auth import OAuthToken, V1APIKeys from app.models.oauth import OAuth2ClientCredentialsBearer from .database import Database @@ -75,10 +75,10 @@ async def v1_authorize( raise HTTPException(status_code=401, detail="Invalid API key") -async def get_client_user( +async def get_client_user_and_token( db: Database, token: Annotated[str, Depends(oauth2_password)], -): +) -> tuple[User, OAuthToken]: token_record = await get_token_by_access_token(db, token) if not token_record: raise HTTPException(status_code=401, detail="Invalid or expired token") @@ -87,17 +87,33 @@ async def get_client_user( if not user: raise HTTPException(status_code=401, detail="Invalid or expired token") - await db.refresh(user) + return user, token_record + + +UserAndToken = tuple[User, OAuthToken] + + +async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get_client_user_and_token)): + return user_and_token[0] + + +async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)): + from app.service.verification_service import LoginSessionService + + user, token = user_and_token + + if await LoginSessionService.check_is_need_verification(db, user.id, token.id): + raise HTTPException(status_code=403, detail="User not verified") return user -async def get_current_user( +async def get_current_user_and_token( db: Database, security_scopes: SecurityScopes, token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None, token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None, -) -> User: +) -> UserAndToken: """获取当前认证用户""" token = token_pw or token_code or token_client_credentials if not token: @@ -120,6 +136,10 @@ async def get_current_user( user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() if not user: raise HTTPException(status_code=401, detail="Invalid or expired token") + return user, token_record - await db.refresh(user) - return user + +async def get_current_user( + user_and_token: UserAndToken = Depends(get_current_user_and_token), +) -> User: + return user_and_token[0] diff --git a/app/models/api_me.py b/app/models/api_me.py deleted file mode 100644 index dab1256..0000000 --- a/app/models/api_me.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -APIMe 响应模型 - 对应 osu! 的 APIMe 类型 -""" - -from __future__ import annotations - -from app.database.lazer_user import UserResp - - -class APIMe(UserResp): - """ - /me 端点的响应模型 - 对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段 - - session_verified 字段已经在 UserResp 中定义,这里不需要重复定义 - """ - - pass diff --git a/app/models/totp.py b/app/models/totp.py new file mode 100644 index 0000000..d07ec29 --- /dev/null +++ b/app/models/totp.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from enum import Enum +from typing import TypedDict + + +class StartCreateTotpKeyResp(TypedDict): + secret: str + uri: str + + +class FinishStatus(str, Enum): + INVALID = "invalid" + SUCCESS = "success" + FAILED = "failed" + TOO_MANY_ATTEMPTS = "too_many_attempts" diff --git a/app/router/auth.py b/app/router/auth.py index b667324..c7b84e7 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -17,6 +17,7 @@ from app.auth import ( from app.config import settings from app.const import BANCHOBOT_ID from app.database import DailyChallengeStats, OAuthClient, User +from app.database.auth import TotpKeys from app.database.statistics import UserStatistics from app.dependencies.database import Database, get_redis from app.dependencies.geoip import get_client_ip, get_geoip_helper @@ -30,12 +31,12 @@ from app.models.oauth import ( UserRegistrationErrors, ) from app.models.score import GameMode -from app.service.email_verification_service import ( +from app.service.login_log_service import LoginLogService +from app.service.password_reset_service import password_reset_service +from app.service.verification_service import ( EmailVerificationService, LoginSessionService, ) -from app.service.login_log_service import LoginLogService -from app.service.password_reset_service import password_reset_service from app.utils import utcnow from fastapi import APIRouter, Depends, Form, Request @@ -287,8 +288,23 @@ async def oauth_token( # 确保用户对象与当前会话关联 await db.refresh(user) - # 获取用户信息和客户端信息 user_id = user.id + totp_key: TotpKeys | None = await user.awaitable_attrs.totp_key + + # 生成令牌 + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires) + refresh_token_str = generate_refresh_token() + token = await store_token( + db, + user_id, + client_id, + scopes, + access_token, + refresh_token_str, + settings.access_token_expire_minutes * 60, + ) + token_id = token.id ip_address = get_client_ip(request) user_agent = request.headers.get("User-Agent", "") @@ -300,15 +316,22 @@ async def oauth_token( # 检查是否为新位置登录 is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code) - # 创建登录会话记录 - login_session = await LoginSessionService.create_session( # noqa: F841 - db, redis, user_id, ip_address, user_agent, country_code, is_new_location - ) - - # 如果是新位置登录,需要邮件验证 - if is_new_location and settings.enable_email_verification: + session_verification_method = None + if settings.enable_totp_verification and totp_key is not None: + session_verification_method = "totp" + await LoginLogService.record_login( + db=db, + user_id=user_id, + request=request, + login_success=True, + login_method="password_pending_verification", + notes="需要 TOTP 验证", + ) + elif is_new_location and settings.enable_email_verification: + # 如果是新位置登录,需要邮件验证 # 刷新用户对象以确保属性已加载 await db.refresh(user) + session_verification_method = "mail" # 发送邮件验证码 verification_sent = await EmailVerificationService.send_verification_email( @@ -328,9 +351,9 @@ async def oauth_token( if not verification_sent: # 邮件发送失败,记录错误 logger.error(f"[Auth] Failed to send email verification code for user {user_id}") - elif is_new_location and not settings.enable_email_verification: + elif is_new_location: # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 - await LoginSessionService.mark_session_verified(db, user_id) + await LoginSessionService.mark_session_verified(db, redis, user_id, token_id) logger.debug( f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}" ) @@ -345,25 +368,16 @@ async def oauth_token( notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}", ) - # 无论是否新位置登录,都返回正常的token - # session_verified状态通过/me接口的session_verified字段来体现 + if session_verification_method: + await LoginSessionService.create_session( + db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, False + ) + await LoginSessionService.set_login_method(user_id, token_id, session_verification_method, redis) + else: + await LoginSessionService.create_session( + db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, True + ) - # 生成令牌 - access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) - # 获取用户ID,避免触发延迟加载 - access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires) - refresh_token_str = generate_refresh_token() - - # 存储令牌 - await store_token( - db, - user_id, - client_id, - scopes, - access_token, - refresh_token_str, - settings.access_token_expire_minutes * 60, - ) return TokenResponse( access_token=access_token, token_type="Bearer", diff --git a/app/router/notification/server.py b/app/router/notification/server.py index bec5b04..16e9836 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -12,7 +12,7 @@ from app.dependencies.database import ( get_redis, with_db, ) -from app.dependencies.user import get_current_user +from app.dependencies.user import get_current_user_and_token from app.log import logger from app.models.chat import ChatEvent from app.models.notification import NotificationDetail @@ -311,7 +311,11 @@ async def chat_websocket( await websocket.close(code=1008, reason="Missing authentication token") return - if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None: + if ( + user_and_token := await get_current_user_and_token( + session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token + ) + ) is None: await websocket.close(code=1008, reason="Invalid or expired token") return @@ -320,6 +324,7 @@ async def chat_websocket( if login.get("event") != "chat.start": await websocket.close(code=1008) return + user = user_and_token[0] user_id = user.id server.connect(user_id, websocket) # 使用明确的查询避免延迟加载 diff --git a/app/router/private/__init__.py b/app/router/private/__init__.py index b17c2f7..4402a47 100644 --- a/app/router/private/__init__.py +++ b/app/router/private/__init__.py @@ -1,8 +1,13 @@ from __future__ import annotations +from app.config import settings + from . import avatar, beatmapset_ratings, cover, oauth, relationship, team, username # noqa: F401 from .router import router as private_router +if settings.enable_totp_verification: + from . import totp # noqa: F401 + __all__ = [ "private_router", ] diff --git a/app/router/private/totp.py b/app/router/private/totp.py new file mode 100644 index 0000000..7e65e57 --- /dev/null +++ b/app/router/private/totp.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from app.auth import ( + check_totp_backup_code, + finish_create_totp_key, + start_create_totp_key, + totp_redis_key, + verify_totp_key, +) +from app.config import settings +from app.const import BACKUP_CODE_LENGTH +from app.database.auth import TotpKeys +from app.database.lazer_user import User +from app.dependencies.database import Database, get_redis +from app.dependencies.user import get_client_user +from app.models.totp import FinishStatus, StartCreateTotpKeyResp + +from .router import router + +from fastapi import Body, Depends, HTTPException, Security +import pyotp +from redis.asyncio import Redis + + +@router.post( + "/totp/create", + name="开始 TOTP 创建流程", + description=( + "开始 TOTP 创建流程\n\n" + "返回 TOTP 密钥和 URI,供用户在身份验证器应用中添加账户。\n\n" + "然后将身份验证器应用提供的 TOTP 代码请求 PUT `/api/private/totp/create` 来完成 TOTP 创建流程。\n\n" + "若 5 分钟内未完成或错误 3 次以上则创建流程需要重新开始。" + ), + tags=["验证", "g0v0 API"], + response_model=StartCreateTotpKeyResp, + status_code=201, +) +async def start_create_totp( + redis: Redis = Depends(get_redis), + current_user: User = Security(get_client_user), +): + if await current_user.awaitable_attrs.totp_key: + raise HTTPException(status_code=400, detail="TOTP is already enabled for this user") + + previous = await redis.hgetall(totp_redis_key(current_user)) # pyright: ignore[reportGeneralTypeIssues] + if previous: # pyright: ignore[reportGeneralTypeIssues] + return StartCreateTotpKeyResp( + secret=previous["secret"], + uri=pyotp.totp.TOTP(previous["secret"]).provisioning_uri( + name=current_user.email, + issuer_name=settings.totp_issuer, + ), + ) + return await start_create_totp_key(current_user, redis) + + +@router.put( + "/totp/create", + name="完成 TOTP 创建流程", + description=( + "完成 TOTP 创建流程,验证用户提供的 TOTP 代码。\n\n" + "- 如果验证成功,启用用户的 TOTP 双因素验证,并返回备份码。\n- 如果验证失败,返回错误信息。" + ), + tags=["验证", "g0v0 API"], + response_model=list[str], + status_code=201, +) +async def finish_create_totp( + session: Database, + code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"), + redis: Redis = Depends(get_redis), + current_user: User = Security(get_client_user), +): + status, backup_codes = await finish_create_totp_key(current_user, code, redis, session) + if status == FinishStatus.SUCCESS: + return backup_codes + elif status == FinishStatus.INVALID: + raise HTTPException(status_code=400, detail="No TOTP setup in progress or invalid data") + elif status == FinishStatus.TOO_MANY_ATTEMPTS: + raise HTTPException(status_code=400, detail="Too many failed attempts. Please start over.") + else: + raise HTTPException(status_code=400, detail="Invalid TOTP code") + + +@router.delete( + "/totp", + name="禁用 TOTP 双因素验证", + description="禁用当前用户的 TOTP 双因素验证", + tags=["验证", "g0v0 API"], + status_code=204, +) +async def disable_totp( + session: Database, + code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"), + current_user: User = Security(get_client_user), +): + totp = await session.get(TotpKeys, current_user.id) + if not totp: + raise HTTPException(status_code=400, detail="TOTP is not enabled for this user") + if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)): + await session.delete(totp) + await session.commit() + else: + raise HTTPException(status_code=400, detail="Invalid TOTP code or backup code") diff --git a/app/router/v2/me.py b/app/router/v2/me.py index 3ff90e3..cab441d 100644 --- a/app/router/v2/me.py +++ b/app/router/v2/me.py @@ -1,11 +1,11 @@ from __future__ import annotations -from app.database import User +from app.database import MeResp, User from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user from app.dependencies.database import Database +from app.dependencies.user import UserAndToken, get_current_user_and_token from app.exceptions.userpage import UserpageError -from app.models.api_me import APIMe from app.models.score import GameMode from app.models.user import Page from app.models.userpage import ( @@ -23,7 +23,7 @@ from fastapi import HTTPException, Path, Security @router.get( "/me/{ruleset}", - response_model=APIMe, + response_model=MeResp, name="获取当前用户信息 (指定 ruleset)", description="获取当前登录用户信息 (含指定 ruleset 统计)。", tags=["用户"], @@ -31,34 +31,24 @@ from fastapi import HTTPException, Path, Security async def get_user_info_with_ruleset( session: Database, ruleset: GameMode = Path(description="指定 ruleset"), - current_user: User = Security(get_current_user, scopes=["identify"]), + user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), ): - user_resp = await APIMe.from_db( - current_user, - session, - ALL_INCLUDED, - ruleset, - ) + user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, ruleset, token_id=user_and_token[1].id) return user_resp @router.get( "/me/", - response_model=APIMe, + response_model=MeResp, name="获取当前用户信息", description="获取当前登录用户信息。", tags=["用户"], ) async def get_user_info_default( session: Database, - current_user: User = Security(get_current_user, scopes=["identify"]), + user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), ): - user_resp = await APIMe.from_db( - current_user, - session, - ALL_INCLUDED, - None, - ) + user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, None, token_id=user_and_token[1].id) return user_resp diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index 11d4a81..60a1b40 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -4,24 +4,35 @@ from __future__ import annotations -from typing import Annotated +from typing import Annotated, Literal -from app.database import User -from app.dependencies import get_current_user +from app.auth import check_totp_backup_code, verify_totp_key +from app.config import settings +from app.const import BACKUP_CODE_LENGTH +from app.database.auth import TotpKeys +from app.dependencies.api_version import APIVersion from app.dependencies.database import Database, get_redis -from app.service.email_verification_service import ( - EmailVerificationService, -) +from app.dependencies.geoip import get_client_ip +from app.dependencies.user import UserAndToken, get_client_user_and_token +from app.log import logger from app.service.login_log_service import LoginLogService +from app.service.verification_service import ( + EmailVerificationService, + LoginSessionService, +) from .router import router from fastapi import Depends, Form, HTTPException, Request, Security, status -from fastapi.responses import Response +from fastapi.responses import JSONResponse, Response from pydantic import BaseModel from redis.asyncio import Redis +class VerifyMethod(BaseModel): + method: Literal["totp", "mail"] = "mail" + + class SessionReissueResponse(BaseModel): """重新发送验证码响应""" @@ -29,66 +40,94 @@ class SessionReissueResponse(BaseModel): message: str +class VerifyFailed(Exception): ... + + @router.post( - "/session/verify", name="验证会话", description="验证邮件验证码并完成会话认证", status_code=204, tags=["验证"] + "/session/verify", + name="验证会话", + description="验证邮件验证码并完成会话认证", + status_code=204, + tags=["验证"], + responses={ + 401: {"model": VerifyMethod, "description": "验证失败,返回当前使用的验证方法"}, + 204: {"description": "验证成功,无内容返回"}, + }, ) async def verify_session( request: Request, db: Database, + api_version: APIVersion, redis: Annotated[Redis, Depends(get_redis)], - verification_key: str = Form(..., description="8位邮件验证码"), - current_user: User = Security(get_current_user), + verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"), + user_and_token: UserAndToken = Security(get_client_user_and_token), ) -> Response: - """ - 验证邮件验证码并完成会话认证 + current_user = user_and_token[0] + token_id = user_and_token[1].id + user_id = current_user.id + + if not await LoginSessionService.check_is_need_verification(db, user_id, token_id): + return Response(status_code=status.HTTP_204_NO_CONTENT) + + verify_method: str | None = ( + "mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis) + ) + + ip_address = get_client_ip(request) + user_agent = request.headers.get("User-Agent", "Unknown") + login_method = "password" - 对应 osu! 的 session/verify 接口 - 成功时返回 204 No Content,失败时返回 401 Unauthorized - """ try: - from app.dependencies.geoip import get_client_ip + totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key + if verify_method is None: + verify_method = "totp" if totp_key else "mail" + await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis) + login_method = verify_method - ip_address = get_client_ip(request) # noqa: F841 - user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841 + if verify_method == "totp": + if not totp_key: + if settings.enable_email_verification: + await LoginSessionService.set_login_method(user_id, token_id, "mail", redis) + await EmailVerificationService.send_verification_email( + db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent + ) + verify_method = "mail" + raise VerifyFailed("用户未设置 TOTP,已发送邮件验证码") + # 如果未开启邮箱验证,则直接认为认证通过 + # 正常不会进入到这里 - # 从当前认证用户获取信息 - user_id = current_user.id - if not user_id: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证") - - # 验证邮件验证码 - success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key) - - if success: - # 记录成功的邮件验证 - await LoginLogService.record_login( - db=db, - user_id=user_id, - request=request, - login_method="email_verification", - login_success=True, - notes="邮件验证成功", - ) - - # 返回 204 No Content 表示验证成功 - return Response(status_code=status.HTTP_204_NO_CONTENT) + elif verify_totp_key(totp_key.secret, verification_key): + pass + elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key): + login_method = "totp_backup_code" + else: + raise VerifyFailed("TOTP 验证失败") else: - # 记录失败的邮件验证尝试 - await LoginLogService.record_failed_login( - db=db, - request=request, - attempted_username=current_user.username, - login_method="email_verification", - notes=f"邮件验证失败: {message}", - ) + success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key) + if not success: + raise VerifyFailed(f"邮件验证失败: {message}") - # 返回 401 Unauthorized 表示验证失败 - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message) + await LoginLogService.record_login( + db=db, + user_id=user_id, + request=request, + login_method=login_method, + login_success=True, + notes=f"{login_method} 验证成功", + ) + await LoginSessionService.mark_session_verified(db, redis, user_id, token_id) + await db.commit() + return Response(status_code=status.HTTP_204_NO_CONTENT) - except ValueError: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话") - except Exception: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误") + except VerifyFailed as e: + await LoginLogService.record_failed_login( + db=db, + request=request, + attempted_username=current_user.username, + login_method=login_method, + notes=str(e), + ) + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method}) @router.post( @@ -101,26 +140,27 @@ async def verify_session( async def reissue_verification_code( request: Request, db: Database, + api_version: APIVersion, redis: Annotated[Redis, Depends(get_redis)], - current_user: User = Security(get_current_user), + user_and_token: UserAndToken = Security(get_client_user_and_token), ) -> SessionReissueResponse: - """ - 重新发送邮件验证码 + current_user = user_and_token[0] + token_id = user_and_token[1].id + user_id = current_user.id + + if not await LoginSessionService.check_is_need_verification(db, user_id, token_id): + return SessionReissueResponse(success=False, message="当前会话不需要验证") + + verify_method: str | None = ( + "mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis) + ) + if verify_method != "mail": + return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码") - 对应 osu! 的 session/verify/reissue 接口 - """ try: - from app.dependencies.geoip import get_client_ip - ip_address = get_client_ip(request) user_agent = request.headers.get("User-Agent", "Unknown") - - # 从当前认证用户获取信息 user_id = current_user.id - if not user_id: - return SessionReissueResponse(success=False, message="用户未认证") - - # 重新发送验证码 success, message = await EmailVerificationService.resend_verification_code( db, redis, @@ -137,3 +177,41 @@ async def reissue_verification_code( return SessionReissueResponse(success=False, message="无效的用户会话") except Exception: return SessionReissueResponse(success=False, message="重新发送过程中发生错误") + + +@router.post( + "/session/verify/mail-fallback", + name="邮件验证码回退", + description="当 TOTP 验证不可用时,使用邮件验证码进行回退验证", + response_model=VerifyMethod, + tags=["验证"], +) +async def fallback_email( + db: Database, + request: Request, + redis: Annotated[Redis, Depends(get_redis)], + user_and_token: UserAndToken = Security(get_client_user_and_token), +) -> VerifyMethod: + current_user = user_and_token[0] + token_id = user_and_token[1].id + if not await LoginSessionService.get_login_method(current_user.id, token_id, redis): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退") + + ip_address = get_client_ip(request) + user_agent = request.headers.get("User-Agent", "Unknown") + + await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis) + success, message = await EmailVerificationService.resend_verification_code( + db, + redis, + current_user.id, + current_user.username, + current_user.email, + ip_address, + user_agent, + ) + if not success: + logger.error( + f"[Email Fallback] Failed to send fallback email to user {current_user.id} (token: {token_id}): {message}" + ) + return VerifyMethod() diff --git a/app/service/database_cleanup_service.py b/app/service/database_cleanup_service.py index f17dc71..75dfce9 100644 --- a/app/service/database_cleanup_service.py +++ b/app/service/database_cleanup_service.py @@ -6,7 +6,7 @@ from __future__ import annotations from datetime import timedelta -from app.database.email_verification import EmailVerification, LoginSession +from app.database.verification import EmailVerification, LoginSession from app.log import logger from app.utils import utcnow diff --git a/app/service/login_log_service.py b/app/service/login_log_service.py index 87d3b7e..d771019 100644 --- a/app/service/login_log_service.py +++ b/app/service/login_log_service.py @@ -9,7 +9,7 @@ import asyncio from app.database.user_login_log import UserLoginLog from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip from app.log import logger -from app.utils import utcnow +from app.utils import simplify_user_agent, utcnow from fastapi import Request from sqlmodel.ext.asyncio.session import AsyncSession @@ -45,9 +45,6 @@ class LoginLogService: raw_ip = get_client_ip(request) ip_address = normalize_ip(raw_ip) - # 获取并简化User-Agent - from app.utils import simplify_user_agent - raw_user_agent = request.headers.get("User-Agent", "") user_agent = simplify_user_agent(raw_user_agent, max_length=500) diff --git a/app/service/email_verification_service.py b/app/service/verification_service.py similarity index 88% rename from app/service/email_verification_service.py rename to app/service/verification_service.py index 960d590..5aec20b 100644 --- a/app/service/email_verification_service.py +++ b/app/service/verification_service.py @@ -7,15 +7,16 @@ from __future__ import annotations from datetime import timedelta import secrets import string +from typing import Literal from app.config import settings -from app.database.email_verification import EmailVerification, LoginSession +from app.database.verification import EmailVerification, LoginSession from app.log import logger from app.service.email_queue import email_queue # 导入邮件队列 from app.utils import utcnow from redis.asyncio import Redis -from sqlmodel import col, select +from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -279,20 +280,18 @@ This email was sent automatically, please do not reply. return False @staticmethod - async def verify_code( + async def verify_email_code( db: AsyncSession, redis: Redis, user_id: int, code: str, ip_address: str | None = None, ) -> tuple[bool, str]: - """验证验证码""" + """验证邮箱验证码""" try: # 检查是否启用邮件验证功能 if not settings.enable_email_verification: logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}") - # 仍然标记登录会话为已验证 - await LoginSessionService.mark_session_verified(db, user_id) return True, "验证成功(邮件验证功能已禁用)" # 先从 Redis 检查 @@ -319,9 +318,6 @@ This email was sent automatically, please do not reply. verification.is_used = True verification.used_at = utcnow() - # 同时更新对应的登录会话状态 - await LoginSessionService.mark_session_verified(db, user_id) - await db.commit() # 删除 Redis 记录 @@ -382,10 +378,12 @@ class LoginSessionService: db: AsyncSession, redis: Redis, user_id: int, + token_id: int, ip_address: str, user_agent: str | None = None, country_code: str | None = None, is_new_location: bool = False, + is_verified: bool = False, ) -> LoginSession: """创建登录会话""" @@ -393,13 +391,13 @@ class LoginSessionService: session = LoginSession( user_id=user_id, - session_token=session_token, + token_id=token_id, ip_address=ip_address, user_agent=None, country_code=country_code, is_new_location=is_new_location, expires_at=utcnow() + timedelta(hours=24), # 24小时过期 - is_verified=not is_new_location, # 新位置需要验证 + is_verified=is_verified, ) db.add(session) @@ -416,46 +414,21 @@ class LoginSessionService: logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})") return session - @staticmethod - async def verify_session( - db: AsyncSession, redis: Redis, session_token: str, verification_code: str - ) -> tuple[bool, str]: - """验证会话(通过邮件验证码)""" - try: - # 从 Redis 获取用户ID - user_id = await redis.get(f"login_session:{session_token}") - if not user_id: - return False, "会话无效或已过期" + @classmethod + def _session_verify_redis_key(cls, user_id: int, token_id: int) -> str: + return f"session_verification_method:{user_id}:{token_id}" - user_id = int(user_id) + @classmethod + async def get_login_method(cls, user_id: int, token_id: int, redis: Redis) -> Literal["totp", "mail"] | None: + return await redis.get(cls._session_verify_redis_key(user_id, token_id)) - # 验证邮件验证码 - success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code) + @classmethod + async def set_login_method(cls, user_id: int, token_id: int, method: Literal["totp", "mail"], redis: Redis) -> None: + await redis.set(cls._session_verify_redis_key(user_id, token_id), method) - if not success: - return False, message - - # 更新会话状态 - result = await db.exec( - select(LoginSession).where( - LoginSession.session_token == session_token, - LoginSession.user_id == user_id, - col(LoginSession.is_verified).is_(False), - ) - ) - - session = result.first() - if session: - session.is_verified = True - session.verified_at = utcnow() - await db.commit() - - logger.info(f"[Login Session] User {user_id} session verification successful") - return True, "会话验证成功" - - except Exception as e: - logger.error(f"[Login Session] Exception during session verification: {e}") - return False, "验证过程中发生错误" + @classmethod + async def clear_login_method(cls, user_id: int, token_id: int, redis: Redis) -> None: + await redis.delete(cls._session_verify_redis_key(user_id, token_id)) @staticmethod async def check_new_location( @@ -485,7 +458,7 @@ class LoginSessionService: return True @staticmethod - async def mark_session_verified(db: AsyncSession, user_id: int) -> bool: + async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool: """标记用户的未验证会话为已验证""" try: # 查找用户所有未验证且未过期的会话 @@ -494,6 +467,7 @@ class LoginSessionService: LoginSession.user_id == user_id, col(LoginSession.is_verified).is_(False), LoginSession.expires_at > utcnow(), + LoginSession.token_id == token_id, ) ) @@ -507,8 +481,27 @@ class LoginSessionService: if sessions: logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}") + await LoginSessionService.clear_login_method(user_id, token_id, redis) + return len(sessions) > 0 except Exception as e: logger.error(f"[Login Session] Exception during marking sessions as verified: {e}") return False + + @staticmethod + async def check_is_need_verification(db: AsyncSession, user_id: int, token_id: int) -> bool: + """检查用户是否需要验证(有未验证的会话)""" + if settings.enable_totp_verification or settings.enable_email_verification: + unverified_session = ( + await db.exec( + select(exists()).where( + LoginSession.user_id == user_id, + col(LoginSession.is_verified).is_(False), + LoginSession.expires_at > utcnow(), + LoginSession.token_id == token_id, + ) + ) + ).first() + return unverified_session or False + return False diff --git a/app/signalr/router.py b/app/signalr/router.py index 9fa316c..cf4bf97 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -9,6 +9,7 @@ import uuid from app.database import User as DBUser from app.dependencies import get_current_user from app.dependencies.database import DBFactory, get_db_factory +from app.dependencies.user import get_current_user_and_token from app.log import logger from app.models.signalr import NegotiateResponse, Transport @@ -61,9 +62,11 @@ async def connect( return try: async for session in factory(): - if (user := await get_current_user(session, SecurityScopes(scopes=["*"]), token_pw=token)) is None or str( - user.id - ) != user_id: + if ( + user_and_token := await get_current_user_and_token( + session, SecurityScopes(scopes=["*"]), token_pw=token + ) + ) is None or str(user_and_token[0].id) != user_id: await websocket.close(code=1008) return except HTTPException: diff --git a/migrations/versions/2025-09-20_15e3a9a05b67_auth_add_totp_keys.py b/migrations/versions/2025-09-20_15e3a9a05b67_auth_add_totp_keys.py new file mode 100644 index 0000000..1eb0eee --- /dev/null +++ b/migrations/versions/2025-09-20_15e3a9a05b67_auth_add_totp_keys.py @@ -0,0 +1,47 @@ +"""auth: add totp keys + +Revision ID: 15e3a9a05b67 +Revises: ebaa317ad928 +Create Date: 2025-09-20 11:27:58.485299 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "15e3a9a05b67" +down_revision: str | Sequence[str] | None = "ebaa317ad928" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "totp_keys", + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("secret", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False), + sa.Column("backup_keys", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("user_id"), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("totp_keys") + # ### end Alembic commands ### diff --git a/migrations/versions/2025-09-21_fe8e9f3da298_login_sessions_remove_session_token_add_.py b/migrations/versions/2025-09-21_fe8e9f3da298_login_sessions_remove_session_token_add_.py new file mode 100644 index 0000000..4f4d0c9 --- /dev/null +++ b/migrations/versions/2025-09-21_fe8e9f3da298_login_sessions_remove_session_token_add_.py @@ -0,0 +1,53 @@ +"""login_sessions: remove session_token & add token_id + +Revision ID: fe8e9f3da298 +Revises: 15e3a9a05b67 +Create Date: 2025-09-21 02:30:58.233846 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "fe8e9f3da298" +down_revision: str | Sequence[str] | None = "15e3a9a05b67" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("login_sessions", sa.Column("token_id", sa.Integer(), nullable=True)) + op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions") + op.create_index(op.f("ix_login_sessions_token_id"), "login_sessions", ["token_id"], unique=False) + op.create_foreign_key(None, "login_sessions", "oauth_tokens", ["token_id"], ["id"], ondelete="SET NULL") + op.drop_column("login_sessions", "session_token") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("login_sessions", sa.Column("session_token", mysql.VARCHAR(length=255), nullable=True)) + connection = op.get_bind() + connection.execute( + sa.text(""" + UPDATE login_sessions + SET session_token = CONCAT('migrated_', id, '_', UNIX_TIMESTAMP(), '_', RAND()) + WHERE session_token IS NULL + """) + ) + op.alter_column("login_sessions", "session_token", nullable=False, type_=mysql.VARCHAR(length=255)) + op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=True) + + op.drop_constraint(op.f("login_sessions_ibfk_1"), "login_sessions", type_="foreignkey") + op.drop_index(op.f("ix_login_sessions_token_id"), table_name="login_sessions") + op.drop_column("login_sessions", "token_id") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 2a51831..3527b16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "pillow>=11.3.0", "pydantic-settings>=2.10.1", "pydantic[email]>=2.5.0", + "pyotp>=2.9.0", "python-dotenv>=1.0.0", "python-jose[cryptography]>=3.3.0", "python-multipart>=0.0.6", diff --git a/uv.lock b/uv.lock index 29f2cd1..f48d358 100644 --- a/uv.lock +++ b/uv.lock @@ -602,6 +602,7 @@ dependencies = [ { name = "pillow" }, { name = "pydantic", extra = ["email"] }, { name = "pydantic-settings" }, + { name = "pyotp" }, { name = "python-dotenv" }, { name = "python-jose", extra = ["cryptography"] }, { name = "python-multipart" }, @@ -645,6 +646,7 @@ requires-dist = [ { name = "pillow", specifier = ">=11.3.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, { name = "pydantic-settings", specifier = ">=2.10.1" }, + { name = "pyotp", specifier = ">=2.9.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-multipart", specifier = ">=0.0.6" }, @@ -1285,6 +1287,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/4c/ad33b92b9864cbde84f259d5df035a6447f91891f5be77788e2a3892bce3/pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9", size = 45300, upload-time = "2025-08-24T12:55:53.394Z" }, ] +[[package]] +name = "pyotp" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" }, +] + [[package]] name = "pyright" version = "1.1.405"