feat(session-verify): 添加 TOTP 支持 (#34)

* chore(deps): add pyotp

* feat(auth): implement TOTP verification

feat(auth): implement TOTP verification and email verification services

- Added TOTP keys management with a new database model `TotpKeys`.
- Introduced `EmailVerification` and `LoginSession` models for email verification.
- Created `verification_service` to handle email verification logic and TOTP processes.
- Updated user response models to include session verification methods.
- Implemented routes for TOTP creation, verification, and fallback to email verification.
- Enhanced login session management to support new location checks and verification methods.
- Added migration script to create `totp_keys` table in the database.

* feat(config): update config example

* docs(totp): complete creating TOTP flow

* refactor(totp): resolve review

* feat(api): forbid unverified request

* fix(totp): trace session by token id to avoid other sessions are forbidden

* chore(linter): make pyright happy

* fix(totp): only mark sessions with a specified token id
This commit is contained in:
MingxuanGame
2025-09-21 19:50:11 +08:00
committed by GitHub
parent 7b4ff1224d
commit 1527e23b43
25 changed files with 684 additions and 235 deletions

View File

@@ -1,5 +1,5 @@
from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthClient, OAuthToken, V1APIKeys
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
from .beatmap import (
Beatmap,
BeatmapResp,
@@ -25,10 +25,10 @@ from .counts import (
ReplayWatchedCount,
)
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .email_verification import EmailVerification, LoginSession
from .events import Event
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
MeResp,
User,
UserResp,
)
@@ -67,6 +67,7 @@ from .user_account_history import (
UserAccountHistoryType,
)
from .user_login_log import UserLoginLog
from .verification import EmailVerification, LoginSession
__all__ = [
"APIUploadedRoom",
@@ -93,6 +94,7 @@ __all__ = [
"ItemAttemptsCount",
"ItemAttemptsResp",
"LoginSession",
"MeResp",
"MonthlyPlaycounts",
"MultiplayerEvent",
"MultiplayerEventResp",
@@ -126,6 +128,7 @@ __all__ = [
"Team",
"TeamMember",
"TeamRequest",
"TotpKeys",
"User",
"UserAccountHistory",
"UserAccountHistoryResp",

View File

@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from app.utils import utcnow
from .verification import LoginSession
from sqlalchemy import Column, DateTime
from sqlmodel import (
JSON,
@@ -23,7 +25,7 @@ if TYPE_CHECKING:
class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__: str = "oauth_tokens"
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True)
@@ -34,6 +36,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
user: "User" = Relationship()
login_session: LoginSession | None = Relationship(back_populates="token", passive_deletes=True)
class OAuthClient(SQLModel, table=True):
@@ -52,3 +55,13 @@ class V1APIKeys(SQLModel, table=True):
name: str = Field(max_length=100, index=True)
key: str = Field(default_factory=secrets.token_hex, index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
class TotpKeys(SQLModel, table=True):
__tablename__: str = "totp_keys"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
secret: str = Field(max_length=100)
backup_keys: list[str] = Field(default_factory=list, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
user: "User" = Relationship(back_populates="totp_key")

View File

@@ -1,8 +1,9 @@
from datetime import datetime, timedelta
import json
from typing import TYPE_CHECKING, NotRequired, TypedDict
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict
from app.config import settings
from app.database.auth import TotpKeys
from app.models.model import UTCBaseModel
from app.models.score import GameMode
from app.models.user import Country, Page
@@ -166,6 +167,7 @@ class User(AsyncAttrs, UserBase, table=True):
back_populates="user",
)
events: list[Event] = Relationship(back_populates="user")
totp_key: TotpKeys | None = Relationship(back_populates="user")
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True)
@@ -255,6 +257,8 @@ class UserResp(UserBase):
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "UserResp":
from app.dependencies.database import get_redis
@@ -421,26 +425,42 @@ class UserResp(UserBase):
)
).one()
# 检查会话验证状态
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
if "session_verified" in include:
from app.service.verification_service import LoginSessionService
if not settings.enable_email_verification:
u.session_verified = True
u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id
else True
)
return u
class MeResp(UserResp):
session_verification_method: Literal["totp", "mail"] | None = None
@classmethod
async def from_db(
cls,
obj: User,
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "MeResp":
from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService
u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id)
u = cls.model_validate(u.model_dump())
if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
redis = get_redis()
if not u.session_verified:
u.session_verification_method = await LoginSessionService.get_login_method(obj.id, token_id, redis)
else:
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
from .email_verification import LoginSession
unverified_session = (
await session.exec(
select(LoginSession).where(
LoginSession.user_id == obj.id,
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
)
)
).first()
u.session_verified = unverified_session is None
u.session_verification_method = None
return u
@@ -455,6 +475,7 @@ ALL_INCLUDED = [
"monthly_playcounts",
"replays_watched_counts",
"rank_history",
"session_verified",
]

View File

@@ -2,14 +2,16 @@
邮件验证相关数据库模型
"""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from app.utils import utcnow
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
from sqlmodel import Field, Integer, Relationship, SQLModel
if TYPE_CHECKING:
from .auth import OAuthToken
class EmailVerification(SQLModel, table=True):
@@ -36,7 +38,9 @@ class LoginSession(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
session_token: str = Field(unique=True, index=True) # 会话令牌
token_id: int | None = Field(
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True)
)
ip_address: str = Field() # 登录IP
user_agent: str | None = Field(default=None, max_length=250)
country_code: str | None = Field(default=None)
@@ -45,3 +49,5 @@ class LoginSession(SQLModel, table=True):
verified_at: datetime | None = Field(default=None)
expires_at: datetime = Field() # 会话过期时间
is_new_location: bool = Field(default=False) # 是否新位置登录
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")