feat(session-verify): 添加 TOTP 支持 (#34)

* chore(deps): add pyotp

* feat(auth): implement TOTP verification

feat(auth): implement TOTP verification and email verification services

- Added TOTP keys management with a new database model `TotpKeys`.
- Introduced `EmailVerification` and `LoginSession` models for email verification.
- Created `verification_service` to handle email verification logic and TOTP processes.
- Updated user response models to include session verification methods.
- Implemented routes for TOTP creation, verification, and fallback to email verification.
- Enhanced login session management to support new location checks and verification methods.
- Added migration script to create `totp_keys` table in the database.

* feat(config): update config example

* docs(totp): complete creating TOTP flow

* refactor(totp): resolve review

* feat(api): forbid unverified request

* fix(totp): trace session by token id to avoid other sessions are forbidden

* chore(linter): make pyright happy

* fix(totp): only mark sessions with a specified token id
This commit is contained in:
MingxuanGame
2025-09-21 19:50:11 +08:00
committed by GitHub
parent 7b4ff1224d
commit 1527e23b43
25 changed files with 684 additions and 235 deletions

View File

@@ -42,7 +42,9 @@ FETCHER_SCOPES="public"
# Logging Settings # Logging Settings
LOG_LEVEL="INFO" LOG_LEVEL="INFO"
# Email Service Settings # Verification Settings
ENABLE_TOTP_VERIFICATION=true
TOTP_ISSUER="osu! server"
ENABLE_EMAIL_VERIFICATION=false ENABLE_EMAIL_VERIFICATION=false
SMTP_SERVER="localhost" SMTP_SERVER="localhost"
SMTP_PORT=587 SMTP_PORT=587

View File

@@ -7,16 +7,20 @@ import secrets
import string import string
from app.config import settings from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.database import ( from app.database import (
OAuthToken, OAuthToken,
User, User,
) )
from app.database.auth import TotpKeys
from app.log import logger from app.log import logger
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from app.utils import utcnow from app.utils import utcnow
import bcrypt import bcrypt
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
import pyotp
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -277,3 +281,76 @@ async def get_user_by_authorization_code(
await db.refresh(user) await db.refresh(user)
return (user, scopes.split(",")) return (user, scopes.split(","))
return None return None
def totp_redis_key(user: User) -> str:
return f"totp:setup:{user.email}"
async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp:
secret = pyotp.random_base32()
await redis.hset(totp_redis_key(user), mapping={"secret": secret, "fails": 0}) # pyright: ignore[reportGeneralTypeIssues]
await redis.expire(totp_redis_key(user), 300)
return StartCreateTotpKeyResp(
secret=secret,
uri=pyotp.totp.TOTP(secret).provisioning_uri(name=user.email, issuer_name=settings.totp_issuer),
)
def verify_totp_key(secret: str, code: str) -> bool:
return pyotp.TOTP(secret).verify(code, valid_window=1)
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
alphabet = string.ascii_uppercase + string.digits
return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)]
async def _store_totp_key(user: User, secret: str, db: AsyncSession) -> list[str]:
backup_codes = _generate_backup_codes()
hashed_codes = [bcrypt.hashpw(code.encode(), bcrypt.gensalt()) for code in backup_codes]
totp_secret = TotpKeys(user_id=user.id, secret=secret, backup_keys=[code.decode() for code in hashed_codes])
db.add(totp_secret)
await db.commit()
return backup_codes
async def finish_create_totp_key(
user: User, code: str, redis: Redis, db: AsyncSession
) -> tuple[FinishStatus, list[str]]:
data = await redis.hgetall(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues]
if not data or "secret" not in data or "fails" not in data:
return FinishStatus.INVALID, []
secret = data["secret"]
fails = int(data["fails"])
if fails >= 3:
await redis.delete(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues]
return FinishStatus.TOO_MANY_ATTEMPTS, []
if verify_totp_key(secret, code):
await redis.delete(totp_redis_key(user)) # pyright: ignore[reportGeneralTypeIssues]
backup_codes = await _store_totp_key(user, secret, db)
return FinishStatus.SUCCESS, backup_codes
else:
fails += 1
await redis.hset(totp_redis_key(user), "fails", str(fails)) # pyright: ignore[reportGeneralTypeIssues]
return FinishStatus.FAILED, []
async def disable_totp(user: User, db: AsyncSession) -> None:
totp = await db.get(TotpKeys, user.id)
if totp:
await db.delete(totp)
await db.commit()
def check_totp_backup_code(totp: TotpKeys, code: str) -> bool:
for hashed_code in totp.backup_keys:
if bcrypt.checkpw(code.encode(), hashed_code.encode()):
copy = totp.backup_keys[:]
copy.remove(hashed_code)
totp.backup_keys = copy
return True
return False

View File

@@ -297,41 +297,47 @@ STORAGE_SETTINGS='{
"日志设置", "日志设置",
] ]
# 邮件服务设置 # 验证服务设置
enable_totp_verification: Annotated[bool, Field(default=True, description="是否启用TOTP双因素验证"), "验证服务设置"]
totp_issuer: Annotated[
str | None,
Field(default=None, description="TOTP 认证器中的发行者名称"),
"验证服务设置",
]
enable_email_verification: Annotated[ enable_email_verification: Annotated[
bool, bool,
Field(default=False, description="是否启用邮件验证功能"), Field(default=False, description="是否启用邮件验证功能"),
"邮件服务设置", "验证服务设置",
] ]
smtp_server: Annotated[ smtp_server: Annotated[
str, str,
Field(default="localhost", description="SMTP 服务器地址"), Field(default="localhost", description="SMTP 服务器地址"),
"邮件服务设置", "验证服务设置",
] ]
smtp_port: Annotated[ smtp_port: Annotated[
int, int,
Field(default=587, description="SMTP 服务器端口"), Field(default=587, description="SMTP 服务器端口"),
"邮件服务设置", "验证服务设置",
] ]
smtp_username: Annotated[ smtp_username: Annotated[
str, str,
Field(default="", description="SMTP 用户名"), Field(default="", description="SMTP 用户名"),
"邮件服务设置", "验证服务设置",
] ]
smtp_password: Annotated[ smtp_password: Annotated[
str, str,
Field(default="", description="SMTP 密码"), Field(default="", description="SMTP 密码"),
"邮件服务设置", "验证服务设置",
] ]
from_email: Annotated[ from_email: Annotated[
str, str,
Field(default="noreply@example.com", description="发件人邮箱"), Field(default="noreply@example.com", description="发件人邮箱"),
"邮件服务设置", "验证服务设置",
] ]
from_name: Annotated[ from_name: Annotated[
str, str,
Field(default="osu! server", description="发件人名称"), Field(default="osu! server", description="发件人名称"),
"邮件服务设置", "验证服务设置",
] ]
# 监控配置 # 监控配置

View File

@@ -1,3 +1,5 @@
from __future__ import annotations from __future__ import annotations
BANCHOBOT_ID = 2 BANCHOBOT_ID = 2
BACKUP_CODE_LENGTH = 10

View File

@@ -1,5 +1,5 @@
from .achievement import UserAchievement, UserAchievementResp from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthClient, OAuthToken, V1APIKeys from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
from .beatmap import ( from .beatmap import (
Beatmap, Beatmap,
BeatmapResp, BeatmapResp,
@@ -25,10 +25,10 @@ from .counts import (
ReplayWatchedCount, ReplayWatchedCount,
) )
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .email_verification import EmailVerification, LoginSession
from .events import Event from .events import Event
from .favourite_beatmapset import FavouriteBeatmapset from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import ( from .lazer_user import (
MeResp,
User, User,
UserResp, UserResp,
) )
@@ -67,6 +67,7 @@ from .user_account_history import (
UserAccountHistoryType, UserAccountHistoryType,
) )
from .user_login_log import UserLoginLog from .user_login_log import UserLoginLog
from .verification import EmailVerification, LoginSession
__all__ = [ __all__ = [
"APIUploadedRoom", "APIUploadedRoom",
@@ -93,6 +94,7 @@ __all__ = [
"ItemAttemptsCount", "ItemAttemptsCount",
"ItemAttemptsResp", "ItemAttemptsResp",
"LoginSession", "LoginSession",
"MeResp",
"MonthlyPlaycounts", "MonthlyPlaycounts",
"MultiplayerEvent", "MultiplayerEvent",
"MultiplayerEventResp", "MultiplayerEventResp",
@@ -126,6 +128,7 @@ __all__ = [
"Team", "Team",
"TeamMember", "TeamMember",
"TeamRequest", "TeamRequest",
"TotpKeys",
"User", "User",
"UserAccountHistory", "UserAccountHistory",
"UserAccountHistoryResp", "UserAccountHistoryResp",

View File

@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.utils import utcnow from app.utils import utcnow
from .verification import LoginSession
from sqlalchemy import Column, DateTime from sqlalchemy import Column, DateTime
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
@@ -23,7 +25,7 @@ if TYPE_CHECKING:
class OAuthToken(UTCBaseModel, SQLModel, table=True): class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__: str = "oauth_tokens" __tablename__: str = "oauth_tokens"
id: int | None = Field(default=None, primary_key=True, index=True) id: int = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
client_id: int = Field(index=True) client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True) access_token: str = Field(max_length=500, unique=True)
@@ -34,6 +36,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime)) created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
user: "User" = Relationship() user: "User" = Relationship()
login_session: LoginSession | None = Relationship(back_populates="token", passive_deletes=True)
class OAuthClient(SQLModel, table=True): class OAuthClient(SQLModel, table=True):
@@ -52,3 +55,13 @@ class V1APIKeys(SQLModel, table=True):
name: str = Field(max_length=100, index=True) name: str = Field(max_length=100, index=True)
key: str = Field(default_factory=secrets.token_hex, index=True) key: str = Field(default_factory=secrets.token_hex, index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
class TotpKeys(SQLModel, table=True):
__tablename__: str = "totp_keys"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
secret: str = Field(max_length=100)
backup_keys: list[str] = Field(default_factory=list, sa_column=Column(JSON))
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
user: "User" = Relationship(back_populates="totp_key")

View File

@@ -1,8 +1,9 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
from typing import TYPE_CHECKING, NotRequired, TypedDict from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict
from app.config import settings from app.config import settings
from app.database.auth import TotpKeys
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.models.score import GameMode from app.models.score import GameMode
from app.models.user import Country, Page from app.models.user import Country, Page
@@ -166,6 +167,7 @@ class User(AsyncAttrs, UserBase, table=True):
back_populates="user", back_populates="user",
) )
events: list[Event] = Relationship(back_populates="user") events: list[Event] = Relationship(back_populates="user")
totp_key: TotpKeys | None = Relationship(back_populates="user")
email: str = Field(max_length=254, unique=True, index=True, exclude=True) email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True) priv: int = Field(default=1, exclude=True)
@@ -255,6 +257,8 @@ class UserResp(UserBase):
session: AsyncSession, session: AsyncSession,
include: list[str] = [], include: list[str] = [],
ruleset: GameMode | None = None, ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "UserResp": ) -> "UserResp":
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
@@ -421,26 +425,42 @@ class UserResp(UserBase):
) )
).one() ).one()
# 检查会话验证状态 if "session_verified" in include:
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true from app.service.verification_service import LoginSessionService
if not settings.enable_email_verification: u.session_verified = (
u.session_verified = True not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id
else True
)
return u
class MeResp(UserResp):
session_verification_method: Literal["totp", "mail"] | None = None
@classmethod
async def from_db(
cls,
obj: User,
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
*,
token_id: int | None = None,
) -> "MeResp":
from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService
u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id)
u = cls.model_validate(u.model_dump())
if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
redis = get_redis()
if not u.session_verified:
u.session_verification_method = await LoginSessionService.get_login_method(obj.id, token_id, redis)
else: else:
# 如果用户有未验证的登录会话,则设置 session_verified 为 false u.session_verification_method = None
from .email_verification import LoginSession
unverified_session = (
await session.exec(
select(LoginSession).where(
LoginSession.user_id == obj.id,
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
)
)
).first()
u.session_verified = unverified_session is None
return u return u
@@ -455,6 +475,7 @@ ALL_INCLUDED = [
"monthly_playcounts", "monthly_playcounts",
"replays_watched_counts", "replays_watched_counts",
"rank_history", "rank_history",
"session_verified",
] ]

View File

@@ -2,14 +2,16 @@
邮件验证相关数据库模型 邮件验证相关数据库模型
""" """
from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Optional
from app.utils import utcnow from app.utils import utcnow
from sqlalchemy import BigInteger, Column, ForeignKey from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel from sqlmodel import Field, Integer, Relationship, SQLModel
if TYPE_CHECKING:
from .auth import OAuthToken
class EmailVerification(SQLModel, table=True): class EmailVerification(SQLModel, table=True):
@@ -36,7 +38,9 @@ class LoginSession(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
session_token: str = Field(unique=True, index=True) # 会话令牌 token_id: int | None = Field(
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True)
)
ip_address: str = Field() # 登录IP ip_address: str = Field() # 登录IP
user_agent: str | None = Field(default=None, max_length=250) user_agent: str | None = Field(default=None, max_length=250)
country_code: str | None = Field(default=None) country_code: str | None = Field(default=None)
@@ -45,3 +49,5 @@ class LoginSession(SQLModel, table=True):
verified_at: datetime | None = Field(default=None) verified_at: datetime | None = Field(default=None)
expires_at: datetime = Field() # 会话过期时间 expires_at: datetime = Field() # 会话过期时间
is_new_location: bool = Field(default=False) # 是否新位置登录 is_new_location: bool = Field(default=False) # 是否新位置登录
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")

View File

@@ -5,7 +5,7 @@ from typing import Annotated
from app.auth import get_token_by_access_token from app.auth import get_token_by_access_token
from app.config import settings from app.config import settings
from app.database import User from app.database import User
from app.database.auth import V1APIKeys from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import Database from .database import Database
@@ -75,10 +75,10 @@ async def v1_authorize(
raise HTTPException(status_code=401, detail="Invalid API key") raise HTTPException(status_code=401, detail="Invalid API key")
async def get_client_user( async def get_client_user_and_token(
db: Database, db: Database,
token: Annotated[str, Depends(oauth2_password)], token: Annotated[str, Depends(oauth2_password)],
): ) -> tuple[User, OAuthToken]:
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
@@ -87,17 +87,33 @@ async def get_client_user(
if not user: if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
await db.refresh(user) return user, token_record
UserAndToken = tuple[User, OAuthToken]
async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get_client_user_and_token)):
return user_and_token[0]
async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)):
from app.service.verification_service import LoginSessionService
user, token = user_and_token
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
raise HTTPException(status_code=403, detail="User not verified")
return user return user
async def get_current_user( async def get_current_user_and_token(
db: Database, db: Database,
security_scopes: SecurityScopes, security_scopes: SecurityScopes,
token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None, token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
) -> User: ) -> UserAndToken:
"""获取当前认证用户""" """获取当前认证用户"""
token = token_pw or token_code or token_client_credentials token = token_pw or token_code or token_client_credentials
if not token: if not token:
@@ -120,6 +136,10 @@ async def get_current_user(
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user: if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
return user, token_record
await db.refresh(user)
return user async def get_current_user(
user_and_token: UserAndToken = Depends(get_current_user_and_token),
) -> User:
return user_and_token[0]

View File

@@ -1,18 +0,0 @@
"""
APIMe 响应模型 - 对应 osu! 的 APIMe 类型
"""
from __future__ import annotations
from app.database.lazer_user import UserResp
class APIMe(UserResp):
"""
/me 端点的响应模型
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
"""
pass

16
app/models/totp.py Normal file
View File

@@ -0,0 +1,16 @@
from __future__ import annotations
from enum import Enum
from typing import TypedDict
class StartCreateTotpKeyResp(TypedDict):
secret: str
uri: str
class FinishStatus(str, Enum):
INVALID = "invalid"
SUCCESS = "success"
FAILED = "failed"
TOO_MANY_ATTEMPTS = "too_many_attempts"

View File

@@ -17,6 +17,7 @@ from app.auth import (
from app.config import settings from app.config import settings
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database import DailyChallengeStats, OAuthClient, User from app.database import DailyChallengeStats, OAuthClient, User
from app.database.auth import TotpKeys
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.geoip import get_client_ip, get_geoip_helper
@@ -30,12 +31,12 @@ from app.models.oauth import (
UserRegistrationErrors, UserRegistrationErrors,
) )
from app.models.score import GameMode from app.models.score import GameMode
from app.service.email_verification_service import ( from app.service.login_log_service import LoginLogService
from app.service.password_reset_service import password_reset_service
from app.service.verification_service import (
EmailVerificationService, EmailVerificationService,
LoginSessionService, LoginSessionService,
) )
from app.service.login_log_service import LoginLogService
from app.service.password_reset_service import password_reset_service
from app.utils import utcnow from app.utils import utcnow
from fastapi import APIRouter, Depends, Form, Request from fastapi import APIRouter, Depends, Form, Request
@@ -287,8 +288,23 @@ async def oauth_token(
# 确保用户对象与当前会话关联 # 确保用户对象与当前会话关联
await db.refresh(user) await db.refresh(user)
# 获取用户信息和客户端信息
user_id = user.id user_id = user.id
totp_key: TotpKeys | None = await user.awaitable_attrs.totp_key
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
token = await store_token(
db,
user_id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
token_id = token.id
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "") user_agent = request.headers.get("User-Agent", "")
@@ -300,15 +316,22 @@ async def oauth_token(
# 检查是否为新位置登录 # 检查是否为新位置登录
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code) is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
# 创建登录会话记录 session_verification_method = None
login_session = await LoginSessionService.create_session( # noqa: F841 if settings.enable_totp_verification and totp_key is not None:
db, redis, user_id, ip_address, user_agent, country_code, is_new_location session_verification_method = "totp"
) await LoginLogService.record_login(
db=db,
# 如果是新位置登录,需要邮件验证 user_id=user_id,
if is_new_location and settings.enable_email_verification: request=request,
login_success=True,
login_method="password_pending_verification",
notes="需要 TOTP 验证",
)
elif is_new_location and settings.enable_email_verification:
# 如果是新位置登录,需要邮件验证
# 刷新用户对象以确保属性已加载 # 刷新用户对象以确保属性已加载
await db.refresh(user) await db.refresh(user)
session_verification_method = "mail"
# 发送邮件验证码 # 发送邮件验证码
verification_sent = await EmailVerificationService.send_verification_email( verification_sent = await EmailVerificationService.send_verification_email(
@@ -328,9 +351,9 @@ async def oauth_token(
if not verification_sent: if not verification_sent:
# 邮件发送失败,记录错误 # 邮件发送失败,记录错误
logger.error(f"[Auth] Failed to send email verification code for user {user_id}") logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
elif is_new_location and not settings.enable_email_verification: elif is_new_location:
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(db, user_id) await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
logger.debug( logger.debug(
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}" f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
) )
@@ -345,25 +368,16 @@ async def oauth_token(
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}", notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
) )
# 无论是否新位置登录都返回正常的token if session_verification_method:
# session_verified状态通过/me接口的session_verified字段来体现 await LoginSessionService.create_session(
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, False
)
await LoginSessionService.set_login_method(user_id, token_id, session_verification_method, redis)
else:
await LoginSessionService.create_session(
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, True
)
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 获取用户ID避免触发延迟加载
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
await store_token(
db,
user_id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer",

View File

@@ -12,7 +12,7 @@ from app.dependencies.database import (
get_redis, get_redis,
with_db, with_db,
) )
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user_and_token
from app.log import logger from app.log import logger
from app.models.chat import ChatEvent from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail from app.models.notification import NotificationDetail
@@ -311,7 +311,11 @@ async def chat_websocket(
await websocket.close(code=1008, reason="Missing authentication token") await websocket.close(code=1008, reason="Missing authentication token")
return return
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token)) is None: if (
user_and_token := await get_current_user_and_token(
session, SecurityScopes(scopes=["chat.read"]), token_pw=auth_token
)
) is None:
await websocket.close(code=1008, reason="Invalid or expired token") await websocket.close(code=1008, reason="Invalid or expired token")
return return
@@ -320,6 +324,7 @@ async def chat_websocket(
if login.get("event") != "chat.start": if login.get("event") != "chat.start":
await websocket.close(code=1008) await websocket.close(code=1008)
return return
user = user_and_token[0]
user_id = user.id user_id = user.id
server.connect(user_id, websocket) server.connect(user_id, websocket)
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载

View File

@@ -1,8 +1,13 @@
from __future__ import annotations from __future__ import annotations
from app.config import settings
from . import avatar, beatmapset_ratings, cover, oauth, relationship, team, username # noqa: F401 from . import avatar, beatmapset_ratings, cover, oauth, relationship, team, username # noqa: F401
from .router import router as private_router from .router import router as private_router
if settings.enable_totp_verification:
from . import totp # noqa: F401
__all__ = [ __all__ = [
"private_router", "private_router",
] ]

104
app/router/private/totp.py Normal file
View File

@@ -0,0 +1,104 @@
from __future__ import annotations
from app.auth import (
check_totp_backup_code,
finish_create_totp_key,
start_create_totp_key,
totp_redis_key,
verify_totp_key,
)
from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys
from app.database.lazer_user import User
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_client_user
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from .router import router
from fastapi import Body, Depends, HTTPException, Security
import pyotp
from redis.asyncio import Redis
@router.post(
"/totp/create",
name="开始 TOTP 创建流程",
description=(
"开始 TOTP 创建流程\n\n"
"返回 TOTP 密钥和 URI供用户在身份验证器应用中添加账户。\n\n"
"然后将身份验证器应用提供的 TOTP 代码请求 PUT `/api/private/totp/create` 来完成 TOTP 创建流程。\n\n"
"若 5 分钟内未完成或错误 3 次以上则创建流程需要重新开始。"
),
tags=["验证", "g0v0 API"],
response_model=StartCreateTotpKeyResp,
status_code=201,
)
async def start_create_totp(
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
):
if await current_user.awaitable_attrs.totp_key:
raise HTTPException(status_code=400, detail="TOTP is already enabled for this user")
previous = await redis.hgetall(totp_redis_key(current_user)) # pyright: ignore[reportGeneralTypeIssues]
if previous: # pyright: ignore[reportGeneralTypeIssues]
return StartCreateTotpKeyResp(
secret=previous["secret"],
uri=pyotp.totp.TOTP(previous["secret"]).provisioning_uri(
name=current_user.email,
issuer_name=settings.totp_issuer,
),
)
return await start_create_totp_key(current_user, redis)
@router.put(
"/totp/create",
name="完成 TOTP 创建流程",
description=(
"完成 TOTP 创建流程,验证用户提供的 TOTP 代码。\n\n"
"- 如果验证成功,启用用户的 TOTP 双因素验证,并返回备份码。\n- 如果验证失败,返回错误信息。"
),
tags=["验证", "g0v0 API"],
response_model=list[str],
status_code=201,
)
async def finish_create_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码"),
redis: Redis = Depends(get_redis),
current_user: User = Security(get_client_user),
):
status, backup_codes = await finish_create_totp_key(current_user, code, redis, session)
if status == FinishStatus.SUCCESS:
return backup_codes
elif status == FinishStatus.INVALID:
raise HTTPException(status_code=400, detail="No TOTP setup in progress or invalid data")
elif status == FinishStatus.TOO_MANY_ATTEMPTS:
raise HTTPException(status_code=400, detail="Too many failed attempts. Please start over.")
else:
raise HTTPException(status_code=400, detail="Invalid TOTP code")
@router.delete(
"/totp",
name="禁用 TOTP 双因素验证",
description="禁用当前用户的 TOTP 双因素验证",
tags=["验证", "g0v0 API"],
status_code=204,
)
async def disable_totp(
session: Database,
code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"),
current_user: User = Security(get_client_user),
):
totp = await session.get(TotpKeys, current_user.id)
if not totp:
raise HTTPException(status_code=400, detail="TOTP is not enabled for this user")
if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)):
await session.delete(totp)
await session.commit()
else:
raise HTTPException(status_code=400, detail="Invalid TOTP code or backup code")

View File

@@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from app.database import User from app.database import MeResp, User
from app.database.lazer_user import ALL_INCLUDED from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import Database from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user_and_token
from app.exceptions.userpage import UserpageError from app.exceptions.userpage import UserpageError
from app.models.api_me import APIMe
from app.models.score import GameMode from app.models.score import GameMode
from app.models.user import Page from app.models.user import Page
from app.models.userpage import ( from app.models.userpage import (
@@ -23,7 +23,7 @@ from fastapi import HTTPException, Path, Security
@router.get( @router.get(
"/me/{ruleset}", "/me/{ruleset}",
response_model=APIMe, response_model=MeResp,
name="获取当前用户信息 (指定 ruleset)", name="获取当前用户信息 (指定 ruleset)",
description="获取当前登录用户信息 (含指定 ruleset 统计)。", description="获取当前登录用户信息 (含指定 ruleset 统计)。",
tags=["用户"], tags=["用户"],
@@ -31,34 +31,24 @@ from fastapi import HTTPException, Path, Security
async def get_user_info_with_ruleset( async def get_user_info_with_ruleset(
session: Database, session: Database,
ruleset: GameMode = Path(description="指定 ruleset"), ruleset: GameMode = Path(description="指定 ruleset"),
current_user: User = Security(get_current_user, scopes=["identify"]), user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
): ):
user_resp = await APIMe.from_db( user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, ruleset, token_id=user_and_token[1].id)
current_user,
session,
ALL_INCLUDED,
ruleset,
)
return user_resp return user_resp
@router.get( @router.get(
"/me/", "/me/",
response_model=APIMe, response_model=MeResp,
name="获取当前用户信息", name="获取当前用户信息",
description="获取当前登录用户信息。", description="获取当前登录用户信息。",
tags=["用户"], tags=["用户"],
) )
async def get_user_info_default( async def get_user_info_default(
session: Database, session: Database,
current_user: User = Security(get_current_user, scopes=["identify"]), user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
): ):
user_resp = await APIMe.from_db( user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, None, token_id=user_and_token[1].id)
current_user,
session,
ALL_INCLUDED,
None,
)
return user_resp return user_resp

View File

@@ -4,24 +4,35 @@
from __future__ import annotations from __future__ import annotations
from typing import Annotated from typing import Annotated, Literal
from app.database import User from app.auth import check_totp_backup_code, verify_totp_key
from app.dependencies import get_current_user from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.database.auth import TotpKeys
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.service.email_verification_service import ( from app.dependencies.geoip import get_client_ip
EmailVerificationService, from app.dependencies.user import UserAndToken, get_client_user_and_token
) from app.log import logger
from app.service.login_log_service import LoginLogService from app.service.login_log_service import LoginLogService
from app.service.verification_service import (
EmailVerificationService,
LoginSessionService,
)
from .router import router from .router import router
from fastapi import Depends, Form, HTTPException, Request, Security, status from fastapi import Depends, Form, HTTPException, Request, Security, status
from fastapi.responses import Response from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel from pydantic import BaseModel
from redis.asyncio import Redis from redis.asyncio import Redis
class VerifyMethod(BaseModel):
method: Literal["totp", "mail"] = "mail"
class SessionReissueResponse(BaseModel): class SessionReissueResponse(BaseModel):
"""重新发送验证码响应""" """重新发送验证码响应"""
@@ -29,66 +40,94 @@ class SessionReissueResponse(BaseModel):
message: str message: str
class VerifyFailed(Exception): ...
@router.post( @router.post(
"/session/verify", name="验证会话", description="验证邮件验证码并完成会话认证", status_code=204, tags=["验证"] "/session/verify",
name="验证会话",
description="验证邮件验证码并完成会话认证",
status_code=204,
tags=["验证"],
responses={
401: {"model": VerifyMethod, "description": "验证失败,返回当前使用的验证方法"},
204: {"description": "验证成功,无内容返回"},
},
) )
async def verify_session( async def verify_session(
request: Request, request: Request,
db: Database, db: Database,
api_version: APIVersion,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8位邮件验证码"), verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
current_user: User = Security(get_current_user), user_and_token: UserAndToken = Security(get_client_user_and_token),
) -> Response: ) -> Response:
""" current_user = user_and_token[0]
验证邮件验证码并完成会话认证 token_id = user_and_token[1].id
user_id = current_user.id
if not await LoginSessionService.check_is_need_verification(db, user_id, token_id):
return Response(status_code=status.HTTP_204_NO_CONTENT)
verify_method: str | None = (
"mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis)
)
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
login_method = "password"
对应 osu! 的 session/verify 接口
成功时返回 204 No Content失败时返回 401 Unauthorized
"""
try: try:
from app.dependencies.geoip import get_client_ip totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key
if verify_method is None:
verify_method = "totp" if totp_key else "mail"
await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
login_method = verify_method
ip_address = get_client_ip(request) # noqa: F841 if verify_method == "totp":
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841 if not totp_key:
if settings.enable_email_verification:
await LoginSessionService.set_login_method(user_id, token_id, "mail", redis)
await EmailVerificationService.send_verification_email(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
)
verify_method = "mail"
raise VerifyFailed("用户未设置 TOTP已发送邮件验证码")
# 如果未开启邮箱验证,则直接认为认证通过
# 正常不会进入到这里
# 从当前认证用户获取信息 elif verify_totp_key(totp_key.secret, verification_key):
user_id = current_user.id pass
if not user_id: elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证") login_method = "totp_backup_code"
else:
# 验证邮件验证码 raise VerifyFailed("TOTP 验证失败")
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
if success:
# 记录成功的邮件验证
await LoginLogService.record_login(
db=db,
user_id=user_id,
request=request,
login_method="email_verification",
login_success=True,
notes="邮件验证成功",
)
# 返回 204 No Content 表示验证成功
return Response(status_code=status.HTTP_204_NO_CONTENT)
else: else:
# 记录失败的邮件验证尝试 success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
await LoginLogService.record_failed_login( if not success:
db=db, raise VerifyFailed(f"邮件验证失败: {message}")
request=request,
attempted_username=current_user.username,
login_method="email_verification",
notes=f"邮件验证失败: {message}",
)
# 返回 401 Unauthorized 表示验证失败 await LoginLogService.record_login(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message) db=db,
user_id=user_id,
request=request,
login_method=login_method,
login_success=True,
notes=f"{login_method} 验证成功",
)
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
await db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
except ValueError: except VerifyFailed as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话") await LoginLogService.record_failed_login(
except Exception: db=db,
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误") request=request,
attempted_username=current_user.username,
login_method=login_method,
notes=str(e),
)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method})
@router.post( @router.post(
@@ -101,26 +140,27 @@ async def verify_session(
async def reissue_verification_code( async def reissue_verification_code(
request: Request, request: Request,
db: Database, db: Database,
api_version: APIVersion,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
current_user: User = Security(get_current_user), user_and_token: UserAndToken = Security(get_client_user_and_token),
) -> SessionReissueResponse: ) -> SessionReissueResponse:
""" current_user = user_and_token[0]
重新发送邮件验证码 token_id = user_and_token[1].id
user_id = current_user.id
if not await LoginSessionService.check_is_need_verification(db, user_id, token_id):
return SessionReissueResponse(success=False, message="当前会话不需要验证")
verify_method: str | None = (
"mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis)
)
if verify_method != "mail":
return SessionReissueResponse(success=False, message="当前会话不支持重新发送验证码")
对应 osu! 的 session/verify/reissue 接口
"""
try: try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown") user_agent = request.headers.get("User-Agent", "Unknown")
# 从当前认证用户获取信息
user_id = current_user.id user_id = current_user.id
if not user_id:
return SessionReissueResponse(success=False, message="用户未认证")
# 重新发送验证码
success, message = await EmailVerificationService.resend_verification_code( success, message = await EmailVerificationService.resend_verification_code(
db, db,
redis, redis,
@@ -137,3 +177,41 @@ async def reissue_verification_code(
return SessionReissueResponse(success=False, message="无效的用户会话") return SessionReissueResponse(success=False, message="无效的用户会话")
except Exception: except Exception:
return SessionReissueResponse(success=False, message="重新发送过程中发生错误") return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
@router.post(
"/session/verify/mail-fallback",
name="邮件验证码回退",
description="当 TOTP 验证不可用时,使用邮件验证码进行回退验证",
response_model=VerifyMethod,
tags=["验证"],
)
async def fallback_email(
db: Database,
request: Request,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
) -> VerifyMethod:
current_user = user_and_token[0]
token_id = user_and_token[1].id
if not await LoginSessionService.get_login_method(current_user.id, token_id, redis):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
success, message = await EmailVerificationService.resend_verification_code(
db,
redis,
current_user.id,
current_user.username,
current_user.email,
ip_address,
user_agent,
)
if not success:
logger.error(
f"[Email Fallback] Failed to send fallback email to user {current_user.id} (token: {token_id}): {message}"
)
return VerifyMethod()

View File

@@ -6,7 +6,7 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from app.database.email_verification import EmailVerification, LoginSession from app.database.verification import EmailVerification, LoginSession
from app.log import logger from app.log import logger
from app.utils import utcnow from app.utils import utcnow

View File

@@ -9,7 +9,7 @@ import asyncio
from app.database.user_login_log import UserLoginLog from app.database.user_login_log import UserLoginLog
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
from app.log import logger from app.log import logger
from app.utils import utcnow from app.utils import simplify_user_agent, utcnow
from fastapi import Request from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -45,9 +45,6 @@ class LoginLogService:
raw_ip = get_client_ip(request) raw_ip = get_client_ip(request)
ip_address = normalize_ip(raw_ip) ip_address = normalize_ip(raw_ip)
# 获取并简化User-Agent
from app.utils import simplify_user_agent
raw_user_agent = request.headers.get("User-Agent", "") raw_user_agent = request.headers.get("User-Agent", "")
user_agent = simplify_user_agent(raw_user_agent, max_length=500) user_agent = simplify_user_agent(raw_user_agent, max_length=500)

View File

@@ -7,15 +7,16 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
import secrets import secrets
import string import string
from typing import Literal
from app.config import settings from app.config import settings
from app.database.email_verification import EmailVerification, LoginSession from app.database.verification import EmailVerification, LoginSession
from app.log import logger from app.log import logger
from app.service.email_queue import email_queue # 导入邮件队列 from app.service.email_queue import email_queue # 导入邮件队列
from app.utils import utcnow from app.utils import utcnow
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, select from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -279,20 +280,18 @@ This email was sent automatically, please do not reply.
return False return False
@staticmethod @staticmethod
async def verify_code( async def verify_email_code(
db: AsyncSession, db: AsyncSession,
redis: Redis, redis: Redis,
user_id: int, user_id: int,
code: str, code: str,
ip_address: str | None = None, ip_address: str | None = None,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""验证验证码""" """验证邮箱验证码"""
try: try:
# 检查是否启用邮件验证功能 # 检查是否启用邮件验证功能
if not settings.enable_email_verification: if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}") logger.debug(f"[Email Verification] Email verification is disabled, auto-approving for user {user_id}")
# 仍然标记登录会话为已验证
await LoginSessionService.mark_session_verified(db, user_id)
return True, "验证成功(邮件验证功能已禁用)" return True, "验证成功(邮件验证功能已禁用)"
# 先从 Redis 检查 # 先从 Redis 检查
@@ -319,9 +318,6 @@ This email was sent automatically, please do not reply.
verification.is_used = True verification.is_used = True
verification.used_at = utcnow() verification.used_at = utcnow()
# 同时更新对应的登录会话状态
await LoginSessionService.mark_session_verified(db, user_id)
await db.commit() await db.commit()
# 删除 Redis 记录 # 删除 Redis 记录
@@ -382,10 +378,12 @@ class LoginSessionService:
db: AsyncSession, db: AsyncSession,
redis: Redis, redis: Redis,
user_id: int, user_id: int,
token_id: int,
ip_address: str, ip_address: str,
user_agent: str | None = None, user_agent: str | None = None,
country_code: str | None = None, country_code: str | None = None,
is_new_location: bool = False, is_new_location: bool = False,
is_verified: bool = False,
) -> LoginSession: ) -> LoginSession:
"""创建登录会话""" """创建登录会话"""
@@ -393,13 +391,13 @@ class LoginSessionService:
session = LoginSession( session = LoginSession(
user_id=user_id, user_id=user_id,
session_token=session_token, token_id=token_id,
ip_address=ip_address, ip_address=ip_address,
user_agent=None, user_agent=None,
country_code=country_code, country_code=country_code,
is_new_location=is_new_location, is_new_location=is_new_location,
expires_at=utcnow() + timedelta(hours=24), # 24小时过期 expires_at=utcnow() + timedelta(hours=24), # 24小时过期
is_verified=not is_new_location, # 新位置需要验证 is_verified=is_verified,
) )
db.add(session) db.add(session)
@@ -416,46 +414,21 @@ class LoginSessionService:
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})") logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
return session return session
@staticmethod @classmethod
async def verify_session( def _session_verify_redis_key(cls, user_id: int, token_id: int) -> str:
db: AsyncSession, redis: Redis, session_token: str, verification_code: str return f"session_verification_method:{user_id}:{token_id}"
) -> tuple[bool, str]:
"""验证会话(通过邮件验证码)"""
try:
# 从 Redis 获取用户ID
user_id = await redis.get(f"login_session:{session_token}")
if not user_id:
return False, "会话无效或已过期"
user_id = int(user_id) @classmethod
async def get_login_method(cls, user_id: int, token_id: int, redis: Redis) -> Literal["totp", "mail"] | None:
return await redis.get(cls._session_verify_redis_key(user_id, token_id))
# 验证邮件验证码 @classmethod
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code) async def set_login_method(cls, user_id: int, token_id: int, method: Literal["totp", "mail"], redis: Redis) -> None:
await redis.set(cls._session_verify_redis_key(user_id, token_id), method)
if not success: @classmethod
return False, message async def clear_login_method(cls, user_id: int, token_id: int, redis: Redis) -> None:
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
# 更新会话状态
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_token,
LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False),
)
)
session = result.first()
if session:
session.is_verified = True
session.verified_at = utcnow()
await db.commit()
logger.info(f"[Login Session] User {user_id} session verification successful")
return True, "会话验证成功"
except Exception as e:
logger.error(f"[Login Session] Exception during session verification: {e}")
return False, "验证过程中发生错误"
@staticmethod @staticmethod
async def check_new_location( async def check_new_location(
@@ -485,7 +458,7 @@ class LoginSessionService:
return True return True
@staticmethod @staticmethod
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool: async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
"""标记用户的未验证会话为已验证""" """标记用户的未验证会话为已验证"""
try: try:
# 查找用户所有未验证且未过期的会话 # 查找用户所有未验证且未过期的会话
@@ -494,6 +467,7 @@ class LoginSessionService:
LoginSession.user_id == user_id, LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False), col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(), LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
) )
) )
@@ -507,8 +481,27 @@ class LoginSessionService:
if sessions: if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}") logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
await LoginSessionService.clear_login_method(user_id, token_id, redis)
return len(sessions) > 0 return len(sessions) > 0
except Exception as e: except Exception as e:
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}") logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
return False return False
@staticmethod
async def check_is_need_verification(db: AsyncSession, user_id: int, token_id: int) -> bool:
"""检查用户是否需要验证(有未验证的会话)"""
if settings.enable_totp_verification or settings.enable_email_verification:
unverified_session = (
await db.exec(
select(exists()).where(
LoginSession.user_id == user_id,
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
)
)
).first()
return unverified_session or False
return False

View File

@@ -9,6 +9,7 @@ import uuid
from app.database import User as DBUser from app.database import User as DBUser
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import DBFactory, get_db_factory from app.dependencies.database import DBFactory, get_db_factory
from app.dependencies.user import get_current_user_and_token
from app.log import logger from app.log import logger
from app.models.signalr import NegotiateResponse, Transport from app.models.signalr import NegotiateResponse, Transport
@@ -61,9 +62,11 @@ async def connect(
return return
try: try:
async for session in factory(): async for session in factory():
if (user := await get_current_user(session, SecurityScopes(scopes=["*"]), token_pw=token)) is None or str( if (
user.id user_and_token := await get_current_user_and_token(
) != user_id: session, SecurityScopes(scopes=["*"]), token_pw=token
)
) is None or str(user_and_token[0].id) != user_id:
await websocket.close(code=1008) await websocket.close(code=1008)
return return
except HTTPException: except HTTPException:

View File

@@ -0,0 +1,47 @@
"""auth: add totp keys
Revision ID: 15e3a9a05b67
Revises: ebaa317ad928
Create Date: 2025-09-20 11:27:58.485299
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "15e3a9a05b67"
down_revision: str | Sequence[str] | None = "ebaa317ad928"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"totp_keys",
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("secret", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
sa.Column("backup_keys", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["lazer_users.id"],
),
sa.PrimaryKeyConstraint("user_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("totp_keys")
# ### end Alembic commands ###

View File

@@ -0,0 +1,53 @@
"""login_sessions: remove session_token & add token_id
Revision ID: fe8e9f3da298
Revises: 15e3a9a05b67
Create Date: 2025-09-21 02:30:58.233846
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "fe8e9f3da298"
down_revision: str | Sequence[str] | None = "15e3a9a05b67"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("login_sessions", sa.Column("token_id", sa.Integer(), nullable=True))
op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions")
op.create_index(op.f("ix_login_sessions_token_id"), "login_sessions", ["token_id"], unique=False)
op.create_foreign_key(None, "login_sessions", "oauth_tokens", ["token_id"], ["id"], ondelete="SET NULL")
op.drop_column("login_sessions", "session_token")
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("login_sessions", sa.Column("session_token", mysql.VARCHAR(length=255), nullable=True))
connection = op.get_bind()
connection.execute(
sa.text("""
UPDATE login_sessions
SET session_token = CONCAT('migrated_', id, '_', UNIX_TIMESTAMP(), '_', RAND())
WHERE session_token IS NULL
""")
)
op.alter_column("login_sessions", "session_token", nullable=False, type_=mysql.VARCHAR(length=255))
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=True)
op.drop_constraint(op.f("login_sessions_ibfk_1"), "login_sessions", type_="foreignkey")
op.drop_index(op.f("ix_login_sessions_token_id"), table_name="login_sessions")
op.drop_column("login_sessions", "token_id")
# ### end Alembic commands ###

View File

@@ -25,6 +25,7 @@ dependencies = [
"pillow>=11.3.0", "pillow>=11.3.0",
"pydantic-settings>=2.10.1", "pydantic-settings>=2.10.1",
"pydantic[email]>=2.5.0", "pydantic[email]>=2.5.0",
"pyotp>=2.9.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"python-jose[cryptography]>=3.3.0", "python-jose[cryptography]>=3.3.0",
"python-multipart>=0.0.6", "python-multipart>=0.0.6",

11
uv.lock generated
View File

@@ -602,6 +602,7 @@ dependencies = [
{ name = "pillow" }, { name = "pillow" },
{ name = "pydantic", extra = ["email"] }, { name = "pydantic", extra = ["email"] },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "pyotp" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" }, { name = "python-multipart" },
@@ -645,6 +646,7 @@ requires-dist = [
{ name = "pillow", specifier = ">=11.3.0" }, { name = "pillow", specifier = ">=11.3.0" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
{ name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "pydantic-settings", specifier = ">=2.10.1" },
{ name = "pyotp", specifier = ">=2.9.0" },
{ name = "python-dotenv", specifier = ">=1.0.0" }, { name = "python-dotenv", specifier = ">=1.0.0" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "python-multipart", specifier = ">=0.0.6" }, { name = "python-multipart", specifier = ">=0.0.6" },
@@ -1285,6 +1287,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/4c/ad33b92b9864cbde84f259d5df035a6447f91891f5be77788e2a3892bce3/pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9", size = 45300, upload-time = "2025-08-24T12:55:53.394Z" }, { url = "https://files.pythonhosted.org/packages/7c/4c/ad33b92b9864cbde84f259d5df035a6447f91891f5be77788e2a3892bce3/pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9", size = 45300, upload-time = "2025-08-24T12:55:53.394Z" },
] ]
[[package]]
name = "pyotp"
version = "2.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" },
]
[[package]] [[package]]
name = "pyright" name = "pyright"
version = "1.1.405" version = "1.1.405"