diff --git a/app/auth.py b/app/auth.py index 4435fdf..669537b 100644 --- a/app/auth.py +++ b/app/auth.py @@ -317,13 +317,51 @@ def totp_redis_key(user: User) -> str: return f"totp:setup:{user.email}" +def _generate_totp_account_label(user: User) -> str: + """生成TOTP账户标签 + + 根据配置选择使用用户名或邮箱,并添加服务器信息使标签更具描述性 + """ + if settings.totp_use_username_in_label: + # 使用用户名作为主要标识 + primary_identifier = user.username + else: + # 使用邮箱作为标识 + primary_identifier = user.email + + # 如果配置了服务名称,添加到标签中以便在认证器中区分 + if settings.totp_service_name: + return f"{primary_identifier} ({settings.totp_service_name})" + else: + return primary_identifier + + +def _generate_totp_issuer_name() -> str: + """生成TOTP发行者名称 + + 优先使用自定义的totp_issuer,否则使用服务名称 + """ + if settings.totp_issuer: + return settings.totp_issuer + elif settings.totp_service_name: + return settings.totp_service_name + else: + # 回退到默认值 + return "osu! Private Server" + + async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp: secret = pyotp.random_base32() await redis.hset(totp_redis_key(user), mapping={"secret": secret, "fails": 0}) # pyright: ignore[reportGeneralTypeIssues] await redis.expire(totp_redis_key(user), 300) + + # 生成更完整的账户标签和issuer信息 + account_label = _generate_totp_account_label(user) + issuer_name = _generate_totp_issuer_name() + return StartCreateTotpKeyResp( secret=secret, - uri=pyotp.totp.TOTP(secret).provisioning_uri(name=user.email, issuer_name=settings.totp_issuer), + uri=pyotp.totp.TOTP(secret).provisioning_uri(name=account_label, issuer_name=issuer_name), ) @@ -331,6 +369,23 @@ def verify_totp_key(secret: str, code: str) -> bool: return pyotp.TOTP(secret).verify(code, valid_window=1) +async def verify_totp_key_with_replay_protection( + user_id: int, secret: str, code: str, redis: Redis +) -> bool: + """验证TOTP密钥,并防止密钥重放攻击""" + if not pyotp.TOTP(secret).verify(code, valid_window=1): + return False + + # 防止120秒内重复使用同一密钥(参考osu-web实现) + cache_key = f"totp:{user_id}:{code}" + if await redis.exists(cache_key): + return False + + # 设置120秒过期时间 + await redis.setex(cache_key, 120, "1") + return True + + def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]: alphabet = string.ascii_uppercase + string.digits return ["".join(secrets.choice(alphabet) for _ in range(length)) for _ in range(count)] diff --git a/app/config.py b/app/config.py index 1805976..cf1107e 100644 --- a/app/config.py +++ b/app/config.py @@ -304,6 +304,16 @@ STORAGE_SETTINGS='{ Field(default=None, description="TOTP 认证器中的发行者名称"), "验证服务设置", ] + totp_service_name: Annotated[ + str, + Field(default="g0v0! Lazer Server", description="TOTP 认证器中显示的服务名称"), + "验证服务设置", + ] + totp_use_username_in_label: Annotated[ + bool, + Field(default=True, description="在TOTP标签中使用用户名而不是邮箱"), + "验证服务设置", + ] enable_email_verification: Annotated[ bool, Field(default=False, description="是否启用邮件验证功能"), @@ -314,6 +324,11 @@ STORAGE_SETTINGS='{ Field(default=True, description="是否启用智能验证(基于客户端类型和设备信任)"), "验证服务设置", ] + enable_session_verification: Annotated[ + bool, + Field(default=True, description="是否启用会话验证中间件"), + "验证服务设置", + ] enable_multi_device_login: Annotated[ bool, Field(default=True, description="是否允许多设备同时登录"), diff --git a/app/database/verification.py b/app/database/verification.py index 52a42d7..f5b9e32 100644 --- a/app/database/verification.py +++ b/app/database/verification.py @@ -49,5 +49,7 @@ class LoginSession(SQLModel, table=True): verified_at: datetime | None = Field(default=None) expires_at: datetime = Field() # 会话过期时间 is_new_location: bool = Field(default=False) # 是否新位置登录 + session_token: str | None = Field(default=None, max_length=64, index=True) # 会话令牌 + verification_method: str | None = Field(default=None, max_length=20) # 验证方法 (totp/mail) token: Optional["OAuthToken"] = Relationship(back_populates="login_session") diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 674df37..aa091b4 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -8,9 +8,11 @@ from app.database import User from app.database.auth import OAuthToken, V1APIKeys from app.models.oauth import OAuth2ClientCredentialsBearer -from .database import Database +from .api_version import APIVersion +from .database import Database, get_redis from fastapi import Depends, HTTPException +from redis.asyncio import Redis from fastapi.security import ( APIKeyQuery, HTTPBearer, @@ -97,13 +99,40 @@ async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get return user_and_token[0] -async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)): +async def get_client_user( + db: Database, + redis: Annotated[Redis, Depends(get_redis)], + api_version: APIVersion, + user_and_token: UserAndToken = Depends(get_client_user_and_token) +): from app.service.verification_service import LoginSessionService user, token = user_and_token if await LoginSessionService.check_is_need_verification(db, user.id, token.id): - raise HTTPException(status_code=403, detail="User not verified") + # 获取当前验证方式 + verify_method = None + if api_version >= 20250913: + verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis) + + if verify_method is None: + # 智能选择验证方式(有TOTP优先TOTP) + totp_key = await user.awaitable_attrs.totp_key + if totp_key is not None and api_version >= 20240101: + verify_method = "totp" + else: + verify_method = "mail" + + # 设置选择的验证方法到Redis中,避免重复选择 + if api_version >= 20250913: + await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis) + + # 返回符合 osu! API 标准的错误响应 + error_response = { + "error": "User not verified", + "method": verify_method + } + raise HTTPException(status_code=401, detail=error_response) return user diff --git a/app/interfaces/session_verification.py b/app/interfaces/session_verification.py new file mode 100644 index 0000000..878274f --- /dev/null +++ b/app/interfaces/session_verification.py @@ -0,0 +1,74 @@ +""" +会话验证接口 + +基于osu-web的SessionVerificationInterface实现 +用于标准化会话验证行为 +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + + +class SessionVerificationInterface(ABC): + """会话验证接口 + + 定义了会话验证所需的基本操作,参考osu-web的实现 + """ + + @classmethod + @abstractmethod + async def find_for_verification(cls, session_id: str) -> Optional[SessionVerificationInterface]: + """根据会话ID查找会话用于验证 + + Args: + session_id: 会话ID + + Returns: + 会话实例或None + """ + pass + + @abstractmethod + def get_key(self) -> str: + """获取会话密钥/ID""" + pass + + @abstractmethod + def get_key_for_event(self) -> str: + """获取用于事件广播的会话密钥""" + pass + + @abstractmethod + def get_verification_method(self) -> Optional[str]: + """获取当前验证方法 + + Returns: + 验证方法 ('totp', 'mail') 或 None + """ + pass + + @abstractmethod + def is_verified(self) -> bool: + """检查会话是否已验证""" + pass + + @abstractmethod + async def mark_verified(self) -> None: + """标记会话为已验证""" + pass + + @abstractmethod + async def set_verification_method(self, method: str) -> None: + """设置验证方法 + + Args: + method: 验证方法 ('totp', 'mail') + """ + pass + + @abstractmethod + def user_id(self) -> Optional[int]: + """获取关联的用户ID""" + pass diff --git a/app/middleware/__init__.py b/app/middleware/__init__.py new file mode 100644 index 0000000..e9464bf --- /dev/null +++ b/app/middleware/__init__.py @@ -0,0 +1,9 @@ +""" +中间件模块 + +提供会话验证和其他中间件功能 +""" + +from .verify_session import VerifySessionMiddleware, SessionState + +__all__ = ["VerifySessionMiddleware", "SessionState"] diff --git a/app/middleware/session_verification.py b/app/middleware/session_verification.py new file mode 100644 index 0000000..44cbfc2 --- /dev/null +++ b/app/middleware/session_verification.py @@ -0,0 +1,318 @@ +""" +会话验证中间件和状态管理 + +基于osu-web的会话验证系统实现 +""" + +from __future__ import annotations + +import asyncio +from typing import Awaitable, Callable, Optional + +from fastapi import HTTPException, Request, Response, status +from fastapi.responses import JSONResponse +from redis.asyncio import Redis +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.database.lazer_user import User +from app.database.verification import LoginSession +from app.dependencies.database import with_db, get_redis +from app.log import logger +from app.service.verification_service import LoginSessionService + + +class SessionVerificationState: + """会话验证状态管理类 + + 参考osu-web的State类实现 + """ + + def __init__(self, session: LoginSession, user: User, redis: Redis): + self.session = session + self.user = user + self.redis = redis + + @classmethod + async def get_current( + cls, + request: Request, + db: AsyncSession, + redis: Redis, + user: User, + ) -> Optional[SessionVerificationState]: + """获取当前会话验证状态""" + try: + # 从请求头或token中获取会话信息 + session_token = cls._extract_session_token(request) + if not session_token: + return None + + # 查找会话 + session = await LoginSessionService.find_for_verification(db, session_token) + if not session or session.user_id != user.id: + return None + + return cls(session, user, redis) + except Exception as e: + logger.error(f"[Session Verification] Error getting current state: {e}") + return None + + @staticmethod + def _extract_session_token(request: Request) -> Optional[str]: + """从请求中提取会话token""" + # 尝试从Authorization header提取 + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:] # 移除"Bearer "前缀 + + # 可以扩展其他提取方式 + return None + + def get_method(self) -> str: + """获取验证方法 + + 参考osu-web的逻辑,智能选择验证方法 + """ + current_method = self.session.verification_method + + if current_method is None: + # 智能选择验证方法 + # 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证 + # 这里简化为检查用户是否有TOTP + totp_key = getattr(self.user, 'totp_key', None) + current_method = 'totp' if totp_key else 'mail' + + # 设置验证方法 + asyncio.create_task(self._set_verification_method(current_method)) + + return current_method + + async def _set_verification_method(self, method: str) -> None: + """内部方法:设置验证方法""" + try: + token_id = self.session.token_id + if token_id is not None and method in ['totp', 'mail']: + # 类型检查确保method是正确的字面量类型 + verification_method = method if method in ['totp', 'mail'] else 'totp' + await LoginSessionService.set_login_method( + self.user.id, token_id, verification_method, self.redis # type: ignore + ) + except Exception as e: + logger.error(f"[Session Verification] Error setting verification method: {e}") + + def is_verified(self) -> bool: + """检查会话是否已验证""" + return self.session.is_verified + + async def mark_verified(self) -> None: + """标记会话为已验证""" + try: + # 创建专用数据库会话 + db = with_db() + try: + token_id = self.session.token_id + if token_id is not None: + await LoginSessionService.mark_session_verified( + db, self.redis, self.user.id, token_id + ) + finally: + await db.close() + except Exception as e: + logger.error(f"[Session Verification] Error marking session verified: {e}") + + def get_key(self) -> str: + """获取会话密钥""" + return str(self.session.id) if self.session.id else "" + + def get_key_for_event(self) -> str: + """获取用于事件广播的会话密钥""" + return LoginSessionService.get_key_for_event(self.get_key()) + + def user_id(self) -> int: + """获取用户ID""" + return self.user.id + + async def issue_mail_if_needed(self) -> None: + """如果需要,发送验证邮件""" + try: + if self.get_method() == "mail": + from app.service.verification_service import EmailVerificationService + + # 创建专用数据库会话发送邮件 + db = with_db() + try: + await EmailVerificationService.send_verification_email( + db, self.redis, self.user.id, self.user.username, + self.user.email, None, None + ) + finally: + await db.close() + except Exception as e: + logger.error(f"[Session Verification] Error issuing mail: {e}") + + +class SessionVerificationController: + """会话验证控制器 + + 参考osu-web的Controller类实现 + """ + + # 需要跳过验证的路由(参考osu-web的SKIP_VERIFICATION_ROUTES) + SKIP_VERIFICATION_ROUTES = { + "/api/v2/session/verify", + "/api/v2/session/verify/reissue", + "/api/v2/me", + "/api/v2/logout", + "/oauth/token", + } + + @staticmethod + def should_skip_verification(request: Request) -> bool: + """检查是否应该跳过验证""" + path = request.url.path + return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES + + @staticmethod + async def initiate_verification( + state: SessionVerificationState, + request: Request, + ) -> Response: + """启动会话验证流程 + + 参考osu-web的initiate方法 + """ + try: + method = state.get_method() + + # 如果是邮件验证,发送验证邮件 + if method == "mail": + await state.issue_mail_if_needed() + + # API请求返回JSON响应 + if request.url.path.startswith("/api/"): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"method": method} + ) + + # 其他情况可以扩展支持HTML响应 + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "authentication": "verify", + "method": method, + "message": "Session verification required" + } + ) + + except Exception as e: + logger.error(f"[Session Verification] Error initiating verification: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Verification initiation failed" + ) + + +class SessionVerificationMiddleware: + """会话验证中间件 + + 参考osu-web的VerifyUser中间件实现 + """ + + def __init__(self, app: Callable[[Request], Awaitable[Response]]): + self.app = app + + async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + """中间件主要逻辑""" + try: + # 检查是否需要跳过验证 + if SessionVerificationController.should_skip_verification(request): + return await call_next(request) + + # 获取依赖项 + user = await self._get_user(request) + if not user: + # 未认证用户跳过验证 + return await call_next(request) + + # 获取数据库和Redis连接 + db = await self._get_db() + redis = await self._get_redis() + + # 获取会话验证状态 + state = await SessionVerificationState.get_current(request, db, redis, user) + if not state: + # 无法获取会话状态,继续请求 + return await call_next(request) + + # 检查是否已验证 + if state.is_verified(): + # 已验证,继续请求 + return await call_next(request) + + # 检查是否需要验证 + if not self._requires_verification(request): + return await call_next(request) + + # 启动验证流程 + return await SessionVerificationController.initiate_verification(state, request) + + except Exception as e: + logger.error(f"[Session Verification Middleware] Unexpected error: {e}") + # 出错时允许请求继续,避免阻塞正常流程 + return await call_next(request) + + async def _get_user(self, request: Request) -> Optional[User]: + """获取当前用户""" + try: + # 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入 + # 简化实现,实际应该从token中解析用户 + return None # 暂时返回None,需要实际实现 + except Exception: + return None + + async def _get_db(self) -> AsyncSession: + """获取数据库连接""" + return with_db() + + async def _get_redis(self) -> Redis: + """获取Redis连接""" + return get_redis() + + def _requires_verification(self, request: Request) -> bool: + """检查是否需要验证 + + 参考osu-web的requiresVerification方法 + """ + method = request.method + + # GET/HEAD/OPTIONS请求一般不需要验证 + safe_methods = {"GET", "HEAD", "OPTIONS"} + if method in safe_methods: + return False + + # POST/PUT/DELETE等修改操作需要验证 + return True + + +# FastAPI中间件包装器 +class FastAPISessionVerificationMiddleware: + """FastAPI会话验证中间件包装器""" + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope, receive) + + async def call_next(req: Request) -> Response: + # 这里需要调用FastAPI应用 + return Response("OK") # 占位符实现 + + middleware = SessionVerificationMiddleware(call_next) + response = await middleware(request, call_next) + + await response(scope, receive, send) diff --git a/app/middleware/setup.py b/app/middleware/setup.py new file mode 100644 index 0000000..13cdca7 --- /dev/null +++ b/app/middleware/setup.py @@ -0,0 +1,44 @@ +""" +中间件设置和配置 + +展示如何将会话验证中间件集成到FastAPI应用中 +""" + +from fastapi import FastAPI + +from app.config import settings +from app.middleware.verify_session import VerifySessionMiddleware + + +def setup_session_verification_middleware(app: FastAPI) -> None: + """设置会话验证中间件 + + Args: + app: FastAPI应用实例 + """ + # 只在启用会话验证时添加中间件 + if settings.enable_session_verification: + app.add_middleware(VerifySessionMiddleware) + + # 可以在这里添加中间件配置日志 + from app.log import logger + logger.info("[Middleware] Session verification middleware enabled") + else: + from app.log import logger + logger.info("[Middleware] Session verification middleware disabled") + + +def setup_all_middlewares(app: FastAPI) -> None: + """设置所有中间件 + + Args: + app: FastAPI应用实例 + """ + # 设置会话验证中间件 + setup_session_verification_middleware(app) + + # 可以在这里添加其他中间件 + # app.add_middleware(OtherMiddleware) + + from app.log import logger + logger.info("[Middleware] All middlewares configured") diff --git a/app/middleware/verify_session.py b/app/middleware/verify_session.py new file mode 100644 index 0000000..8d406e3 --- /dev/null +++ b/app/middleware/verify_session.py @@ -0,0 +1,287 @@ +""" +FastAPI会话验证中间件 + +基于osu-web的会话验证系统,适配FastAPI框架 +""" + +from __future__ import annotations + +from typing import Callable, Optional + +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from redis.asyncio import Redis +from sqlmodel.ext.asyncio.session import AsyncSession +from starlette.middleware.base import BaseHTTPMiddleware + +from app.database.lazer_user import User +from app.database.verification import LoginSession +from app.dependencies.database import with_db, get_redis +from app.auth import get_token_by_access_token +from app.log import logger +from app.service.verification_service import LoginSessionService +from sqlmodel import select + + +class VerifySessionMiddleware(BaseHTTPMiddleware): + """会话验证中间件 + + 参考osu-web的VerifyUser中间件,适配FastAPI + """ + + # 需要跳过验证的路由 + SKIP_VERIFICATION_ROUTES = { + "/api/v2/session/verify", + "/api/v2/session/verify/reissue", + "/api/v2/me", + "/api/v2/logout", + "/oauth/token", + "/health", + "/metrics", + "/docs", + "/openapi.json", + "/redoc", + } + + # 需要强制验证的路由模式(敏感操作) + ALWAYS_VERIFY_PATTERNS = { + "/api/v2/account/", + "/api/v2/settings/", + "/api/private/admin/", + } + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """中间件主处理逻辑""" + try: + # 检查是否跳过验证 + if self._should_skip_verification(request): + return await call_next(request) + + # 获取当前用户 + user = await self._get_current_user(request) + if not user: + # 未登录用户跳过验证 + return await call_next(request) + + # 获取会话状态 + session_state = await self._get_session_state(request, user) + if not session_state: + # 无会话状态,继续请求 + return await call_next(request) + + # 检查是否已验证 + if session_state.is_verified(): + return await call_next(request) + + # 检查是否需要验证 + if not self._requires_verification(request, user): + return await call_next(request) + + # 启动验证流程 + return await self._initiate_verification(request, session_state) + + except Exception as e: + logger.error(f"[Verify Session Middleware] Error: {e}") + # 出错时允许请求继续,避免阻塞 + return await call_next(request) + + def _should_skip_verification(self, request: Request) -> bool: + """检查是否应该跳过验证""" + path = request.url.path + + # 完全匹配的跳过路由 + if path in self.SKIP_VERIFICATION_ROUTES: + return True + + # 非API请求跳过 + if not path.startswith("/api/"): + return True + + return False + + def _requires_verification(self, request: Request, user: User) -> bool: + """检查是否需要验证""" + path = request.url.path + method = request.method + + # 检查是否为强制验证的路由 + for pattern in self.ALWAYS_VERIFY_PATTERNS: + if path.startswith(pattern): + return True + + # 特权用户或非活跃用户需要验证 + if hasattr(user, 'is_privileged') and user.is_privileged(): + return True + if hasattr(user, 'is_inactive') and user.is_inactive(): + return True + + # 安全方法(GET/HEAD/OPTIONS)一般不需要验证 + safe_methods = {"GET", "HEAD", "OPTIONS"} + if method in safe_methods: + return False + + # 修改操作(POST/PUT/DELETE/PATCH)需要验证 + return method in {"POST", "PUT", "DELETE", "PATCH"} + + async def _get_current_user(self, request: Request) -> Optional[User]: + """获取当前用户""" + try: + # 从Authorization header提取token + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return None + + token = auth_header[7:] # 移除"Bearer "前缀 + + # 创建专用数据库会话 + db = with_db() + try: + # 获取token记录 + token_record = await get_token_by_access_token(db, token) + if not token_record: + return None + + # 获取用户 + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() + return user + finally: + await db.close() + + except Exception as e: + logger.debug(f"[Verify Session Middleware] Error getting user: {e}") + return None + + async def _get_session_state(self, request: Request, user: User) -> Optional[SessionState]: + """获取会话状态""" + try: + # 提取会话token(这里简化为使用相同的auth token) + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return None + + session_token = auth_header[7:] + + # 获取数据库和Redis连接 + db = with_db() + try: + redis = get_redis() + + # 查找会话 + session = await LoginSessionService.find_for_verification(db, session_token) + if not session or session.user_id != user.id: + return None + + return SessionState(session, user, redis, db) + finally: + await db.close() + + except Exception as e: + logger.error(f"[Verify Session Middleware] Error getting session state: {e}") + return None + + async def _initiate_verification(self, request: Request, state: SessionState) -> Response: + """启动验证流程""" + try: + method = await state.get_method() + + # 如果是邮件验证,可以在这里触发发送邮件 + if method == "mail": + await state.issue_mail_if_needed() + + # 返回验证要求响应 + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "method": method, + "message": "Session verification required" + } + ) + + except Exception as e: + logger.error(f"[Verify Session Middleware] Error initiating verification: {e}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"error": "Verification initiation failed"} + ) + + +class SessionState: + """会话状态类 + + 简化版本的会话状态管理 + """ + + def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession): + self.session = session + self.user = user + self.redis = redis + self.db = db + self._verification_method: Optional[str] = None + + def is_verified(self) -> bool: + """检查会话是否已验证""" + return self.session.is_verified + + async def get_method(self) -> str: + """获取验证方法""" + if self._verification_method is None: + # 从Redis获取已设置的方法 + token_id = self.session.token_id + if token_id is not None: + self._verification_method = await LoginSessionService.get_login_method( + self.user.id, token_id, self.redis + ) + + # 如果没有设置,智能选择 + if self._verification_method is None: + # 检查用户是否有TOTP密钥 + await self.user.awaitable_attrs.totp_key # 预加载 + totp_key = getattr(self.user, 'totp_key', None) + self._verification_method = 'totp' if totp_key else 'mail' + + # 保存选择的方法 + token_id = self.session.token_id + if token_id is not None: + await LoginSessionService.set_login_method( + self.user.id, token_id, self._verification_method, self.redis + ) + + return self._verification_method + + async def mark_verified(self) -> None: + """标记会话为已验证""" + try: + token_id = self.session.token_id + if token_id is not None: + await LoginSessionService.mark_session_verified( + self.db, self.redis, self.user.id, token_id + ) + self.session.is_verified = True # 更新本地状态 + except Exception as e: + logger.error(f"[Session State] Error marking verified: {e}") + + async def issue_mail_if_needed(self) -> None: + """如果需要,发送验证邮件""" + try: + if await self.get_method() == "mail": + from app.service.verification_service import EmailVerificationService + + # 这里可以触发邮件发送 + await EmailVerificationService.send_verification_email( + self.db, self.redis, self.user.id, self.user.username, + self.user.email, None, None + ) + except Exception as e: + logger.error(f"[Session State] Error issuing mail: {e}") + + def get_key(self) -> str: + """获取会话密钥""" + return str(self.session.id) if self.session.id else "" + + def get_key_for_event(self) -> str: + """获取用于事件广播的会话密钥""" + return LoginSessionService.get_key_for_event(self.get_key()) + + def user_id(self) -> int: + """获取用户ID""" + return self.user.id diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 16e9836..7fe8859 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -48,9 +48,18 @@ class ChatServer: user_id = user.id if user_id in self.connect_client: del self.connect_client[user_id] + + # 创建频道ID列表的副本以避免在迭代过程中修改字典 + channel_ids_to_process = [] for channel_id, channel in self.channels.items(): if user_id in channel: - channel.remove(user_id) + channel_ids_to_process.append(channel_id) + + # 现在安全地处理每个频道 + for channel_id in channel_ids_to_process: + # 再次检查用户是否仍在频道中(防止并发修改) + if channel_id in self.channels and user_id in self.channels[channel_id]: + self.channels[channel_id].remove(user_id) # 使用明确的查询避免延迟加载 db_channel = ( await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id)) diff --git a/app/router/private/totp.py b/app/router/private/totp.py index 7e65e57..8dcb1c3 100644 --- a/app/router/private/totp.py +++ b/app/router/private/totp.py @@ -5,9 +5,8 @@ from app.auth import ( finish_create_totp_key, start_create_totp_key, totp_redis_key, - verify_totp_key, + verify_totp_key_with_replay_protection, ) -from app.config import settings from app.const import BACKUP_CODE_LENGTH from app.database.auth import TotpKeys from app.database.lazer_user import User @@ -19,9 +18,38 @@ from .router import router from fastapi import Body, Depends, HTTPException, Security import pyotp +from pydantic import BaseModel from redis.asyncio import Redis +class TotpStatusResp(BaseModel): + """TOTP状态响应""" + enabled: bool + created_at: str | None = None + + +@router.get( + "/totp/status", + name="检查 TOTP 状态", + description="检查当前用户是否已启用 TOTP 双因素验证", + tags=["验证", "g0v0 API"], + response_model=TotpStatusResp, +) +async def get_totp_status( + current_user: User = Security(get_client_user), +): + """检查用户是否已创建TOTP""" + totp_key = await current_user.awaitable_attrs.totp_key + + if totp_key: + return TotpStatusResp( + enabled=True, + created_at=totp_key.created_at.isoformat() + ) + else: + return TotpStatusResp(enabled=False) + + @router.post( "/totp/create", name="开始 TOTP 创建流程", @@ -44,11 +72,16 @@ async def start_create_totp( previous = await redis.hgetall(totp_redis_key(current_user)) # pyright: ignore[reportGeneralTypeIssues] if previous: # pyright: ignore[reportGeneralTypeIssues] + from app.auth import _generate_totp_account_label, _generate_totp_issuer_name + + account_label = _generate_totp_account_label(current_user) + issuer_name = _generate_totp_issuer_name() + return StartCreateTotpKeyResp( secret=previous["secret"], uri=pyotp.totp.TOTP(previous["secret"]).provisioning_uri( - name=current_user.email, - issuer_name=settings.totp_issuer, + name=account_label, + issuer_name=issuer_name, ), ) return await start_create_totp_key(current_user, redis) @@ -92,12 +125,21 @@ async def finish_create_totp( async def disable_totp( session: Database, code: str = Body(..., embed=True, description="用户提供的 TOTP 代码或备份码"), + redis: Redis = Depends(get_redis), current_user: User = Security(get_client_user), ): totp = await session.get(TotpKeys, current_user.id) if not totp: raise HTTPException(status_code=400, detail="TOTP is not enabled for this user") - if verify_totp_key(totp.secret, code) or (len(code) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp, code)): + + # 使用防重放保护的TOTP验证或备份码验证 + is_totp_valid = False + if len(code) == 6 and code.isdigit(): + is_totp_valid = await verify_totp_key_with_replay_protection(current_user.id, totp.secret, code, redis) + elif len(code) == BACKUP_CODE_LENGTH: + is_totp_valid = check_totp_backup_code(totp, code) + + if is_totp_valid: await session.delete(totp) await session.commit() else: diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index 60a1b40..ce60225 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -6,7 +6,7 @@ from __future__ import annotations from typing import Annotated, Literal -from app.auth import check_totp_backup_code, verify_totp_key +from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection from app.config import settings from app.const import BACKUP_CODE_LENGTH from app.database.auth import TotpKeys @@ -40,7 +40,11 @@ class SessionReissueResponse(BaseModel): message: str -class VerifyFailed(Exception): ... +class VerifyFailed(Exception): + def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False): + super().__init__(message) + self.reason = reason + self.should_reissue = should_reissue @router.post( @@ -80,28 +84,42 @@ async def verify_session( try: totp_key: TotpKeys | None = await current_user.awaitable_attrs.totp_key if verify_method is None: - verify_method = "totp" if totp_key else "mail" + # 智能选择验证方法(参考osu-web实现) + # API版本较老或用户未设置TOTP时强制使用邮件验证 + #print(api_version, totp_key) + if api_version < 20240101 or totp_key is None: + verify_method = "mail" + else: + verify_method = "totp" await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis) login_method = verify_method if verify_method == "totp": if not totp_key: + # TOTP密钥在验证开始和现在之间被删除(参考osu-web的fallback机制) if settings.enable_email_verification: await LoginSessionService.set_login_method(user_id, token_id, "mail", redis) await EmailVerificationService.send_verification_email( db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent ) verify_method = "mail" - raise VerifyFailed("用户未设置 TOTP,已发送邮件验证码") + raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证") # 如果未开启邮箱验证,则直接认为认证通过 # 正常不会进入到这里 - elif verify_totp_key(totp_key.secret, verification_key): + elif await verify_totp_key_with_replay_protection(user_id, totp_key.secret, verification_key, redis): pass elif len(verification_key) == BACKUP_CODE_LENGTH and check_totp_backup_code(totp_key, verification_key): login_method = "totp_backup_code" else: - raise VerifyFailed("TOTP 验证失败") + # 记录详细的验证失败原因(参考osu-web的错误处理) + if len(verification_key) != 6: + raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length") + elif not verification_key.isdigit(): + raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format") + else: + # 可能是密钥错误或者重放攻击 + raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key") else: success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key) if not success: @@ -127,7 +145,28 @@ async def verify_session( login_method=login_method, notes=str(e), ) - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": verify_method}) + + # 构建更详细的错误响应(参考osu-web的错误处理) + error_response = { + "error": str(e), + "method": verify_method, + } + + # 如果有具体的错误原因,添加到响应中 + if hasattr(e, 'reason') and e.reason: + error_response["reason"] = e.reason + + # 如果需要重新发送邮件验证码 + if hasattr(e, 'should_reissue') and e.should_reissue and verify_method == "mail": + try: + await EmailVerificationService.send_verification_email( + db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent + ) + error_response["reissued"] = True + except Exception: + pass # 忽略重发邮件失败的错误 + + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response) @router.post( diff --git a/app/scheduler/database_cleanup_scheduler.py b/app/scheduler/database_cleanup_scheduler.py index 1a21bcf..43a94a2 100644 --- a/app/scheduler/database_cleanup_scheduler.py +++ b/app/scheduler/database_cleanup_scheduler.py @@ -74,10 +74,13 @@ class DatabaseCleanupScheduler: # 清理过期的登录会话 expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db) + # 清理1小时前未验证的登录会话 + unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1) + # 只在有清理记录时输出总结 - total_cleaned = expired_codes + expired_sessions + total_cleaned = expired_codes + expired_sessions + unverified_sessions if total_cleaned > 0: - logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}") + logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}") except Exception as e: logger.error(f"Error during scheduled database cleanup: {e!s}") diff --git a/app/service/database_cleanup_service.py b/app/service/database_cleanup_service.py index 75dfce9..6fea2df 100644 --- a/app/service/database_cleanup_service.py +++ b/app/service/database_cleanup_service.py @@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession from app.log import logger from app.utils import utcnow -from sqlmodel import col, select +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -107,7 +107,7 @@ class DatabaseCleanupService: # 查找指定天数前的已使用验证码记录 cutoff_time = utcnow() - timedelta(days=days_old) - stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) + stmt = select(EmailVerification).where(EmailVerification.is_used == True) result = await db.exec(stmt) all_used_codes = result.all() @@ -134,6 +134,50 @@ class DatabaseCleanupService: logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}") return 0 + @staticmethod + async def cleanup_unverified_login_sessions(db: AsyncSession, hours_old: int = 1) -> int: + """ + 清理指定小时前创建但仍未验证的登录会话 + + Args: + db: 数据库会话 + hours_old: 清理多少小时前创建但仍未验证的会话,默认1小时 + + Returns: + int: 清理的记录数 + """ + try: + # 计算截止时间 + cutoff_time = utcnow() - timedelta(hours=hours_old) + + # 查找指定时间前创建且仍未验证的会话记录 + stmt = select(LoginSession).where( + LoginSession.is_verified == False, + LoginSession.created_at < cutoff_time + ) + result = await db.exec(stmt) + unverified_sessions = result.all() + + # 删除未验证的会话记录 + deleted_count = 0 + for session in unverified_sessions: + await db.delete(session) + deleted_count += 1 + + await db.commit() + + if deleted_count > 0: + logger.debug( + f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)" + ) + + return deleted_count + + except Exception as e: + await db.rollback() + logger.error(f"[Cleanup Service] Error cleaning unverified login sessions: {e!s}") + return 0 + @staticmethod async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int: """ @@ -150,7 +194,7 @@ class DatabaseCleanupService: # 查找指定天数前的已验证会话记录 cutoff_time = utcnow() - timedelta(days=days_old) - stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) + stmt = select(LoginSession).where(LoginSession.is_verified == True) result = await db.exec(stmt) all_verified_sessions = result.all() @@ -200,6 +244,9 @@ class DatabaseCleanupService: # 清理过期的登录会话 results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db) + # 清理1小时前未验证的登录会话 + results["unverified_login_sessions"] = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1) + # 清理7天前的已使用验证码 results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7) @@ -227,6 +274,7 @@ class DatabaseCleanupService: """ try: current_time = utcnow() + cutoff_1_hour = current_time - timedelta(hours=1) cutoff_7_days = current_time - timedelta(days=7) cutoff_30_days = current_time - timedelta(days=30) @@ -240,8 +288,16 @@ class DatabaseCleanupService: expired_sessions_result = await db.exec(expired_sessions_stmt) expired_sessions_count = len(expired_sessions_result.all()) + # 统计1小时前未验证的登录会话数量 + unverified_sessions_stmt = select(LoginSession).where( + LoginSession.is_verified == False, + LoginSession.created_at < cutoff_1_hour + ) + unverified_sessions_result = await db.exec(unverified_sessions_stmt) + unverified_sessions_count = len(unverified_sessions_result.all()) + # 统计7天前的已使用验证码数量 - old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) + old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True) old_used_codes_result = await db.exec(old_used_codes_stmt) all_used_codes = old_used_codes_result.all() old_used_codes_count = len( @@ -249,7 +305,7 @@ class DatabaseCleanupService: ) # 统计30天前的已验证会话数量 - old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) + old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True) old_verified_sessions_result = await db.exec(old_verified_sessions_stmt) all_verified_sessions = old_verified_sessions_result.all() old_verified_sessions_count = len( @@ -263,10 +319,12 @@ class DatabaseCleanupService: return { "expired_verification_codes": expired_codes_count, "expired_login_sessions": expired_sessions_count, + "unverified_login_sessions": unverified_sessions_count, "old_used_verification_codes": old_used_codes_count, "old_verified_sessions": old_verified_sessions_count, "total_cleanable": expired_codes_count + expired_sessions_count + + unverified_sessions_count + old_used_codes_count + old_verified_sessions_count, } @@ -276,6 +334,7 @@ class DatabaseCleanupService: return { "expired_verification_codes": 0, "expired_login_sessions": 0, + "unverified_login_sessions": 0, "old_used_verification_codes": 0, "old_verified_sessions": 0, "total_cleanable": 0, diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py index 3d76248..8b4f3cf 100644 --- a/app/service/ranking_cache_service.py +++ b/app/service/ranking_cache_service.py @@ -330,6 +330,7 @@ class RankingCacheService: # 计算统计信息 stats = { + "total": total_users, "total_users": total_users, "last_updated": utcnow().isoformat(), "type": type, diff --git a/app/service/verification_service.py b/app/service/verification_service.py index 1a1dfd2..32c72bf 100644 --- a/app/service/verification_service.py +++ b/app/service/verification_service.py @@ -7,10 +7,11 @@ from __future__ import annotations from datetime import timedelta import secrets import string -from typing import Literal +from typing import Literal, Optional from app.config import settings from app.database.verification import EmailVerification, LoginSession +from app.interfaces.session_verification import SessionVerificationInterface from app.log import logger from app.service.client_detection_service import ClientDetectionService, ClientInfo from app.service.device_trust_service import DeviceTrustService @@ -514,6 +515,26 @@ This email was sent automatically, please do not reply. class LoginSessionService: """登录会话服务""" + # Session verification interface methods + @staticmethod + async def find_for_verification(db: AsyncSession, session_id: str) -> Optional[LoginSession]: + """根据会话ID查找会话用于验证""" + try: + result = await db.exec( + select(LoginSession).where( + LoginSession.session_token == session_id, + LoginSession.expires_at > utcnow(), + ) + ) + return result.first() + except Exception: + return None + + @staticmethod + def get_key_for_event(session_id: str) -> str: + """获取用于事件广播的会话密钥""" + return f"g0v0:{session_id}" + @staticmethod async def create_session( db: AsyncSession, diff --git a/main.py b/main.py index 23c0851..b1463eb 100644 --- a/main.py +++ b/main.py @@ -46,6 +46,8 @@ from fastapi.responses import JSONResponse, RedirectResponse from fastapi_limiter import FastAPILimiter import sentry_sdk +from app.middleware.verify_session import VerifySessionMiddleware + @asynccontextmanager async def lifespan(app: FastAPI): @@ -171,6 +173,10 @@ app.include_router(lio_router) # from app.signalr import signalr_router # app.include_router(signalr_router) +# 会话验证中间件 +if settings.enable_session_verification: + app.add_middleware(VerifySessionMiddleware) + # CORS 配置 origins = [] for url in [*settings.cors_urls, settings.server_url]: diff --git a/migrations/versions/2025-09-24_9419272e4c85_feat_db_add_session_verification_fields_.py b/migrations/versions/2025-09-24_9419272e4c85_feat_db_add_session_verification_fields_.py new file mode 100644 index 0000000..0207889 --- /dev/null +++ b/migrations/versions/2025-09-24_9419272e4c85_feat_db_add_session_verification_fields_.py @@ -0,0 +1,38 @@ +"""feat(db): add session verification fields to login_session + +Revision ID: 9419272e4c85 +Revises: fe8e9f3da298 +Create Date: 2025-09-24 00:46:57.367742 + +""" +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "9419272e4c85" +down_revision: str | Sequence[str] | None = "fe8e9f3da298" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("login_sessions", sa.Column("session_token", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=True)) + op.add_column("login_sessions", sa.Column("verification_method", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True)) + op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions") + op.drop_column("login_sessions", "verification_method") + op.drop_column("login_sessions", "session_token") + # ### end Alembic commands ###