diff --git a/app/auth.py b/app/auth.py index cb13d75..8a8bf01 100644 --- a/app/auth.py +++ b/app/auth.py @@ -217,10 +217,12 @@ async def store_token( access_token: str, refresh_token: str, expires_in: int, + refresh_token_expires_in: int, allow_multiple_devices: bool = True, ) -> OAuthToken: """存储令牌到数据库(支持多设备)""" expires_at = utcnow() + timedelta(seconds=expires_in) + refresh_token_expires_at = utcnow() + timedelta(seconds=refresh_token_expires_in) if not allow_multiple_devices: # 旧的行为:删除用户的旧令牌(单设备模式) @@ -266,6 +268,7 @@ async def store_token( scope=",".join(scopes), refresh_token=refresh_token, expires_at=expires_at, + refresh_token_expires_at=refresh_token_expires_at, ) db.add(token_record) await db.commit() @@ -290,7 +293,7 @@ async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OA """根据刷新令牌获取令牌记录""" statement = select(OAuthToken).where( OAuthToken.refresh_token == refresh_token, - OAuthToken.expires_at > utcnow(), + OAuthToken.refresh_token_expires_at > utcnow(), ) return (await db.exec(statement)).first() diff --git a/app/config.py b/app/config.py index 0254364..1761d7d 100644 --- a/app/config.py +++ b/app/config.py @@ -170,6 +170,11 @@ STORAGE_SETTINGS='{ Field(default=1440, description="访问令牌过期时间(分钟)"), "JWT 设置", ] + refresh_token_expire_minutes: Annotated[ + int, + Field(default=21600, description="刷新令牌过期时间(分钟)"), + "JWT 设置", + ] # 15 days jwt_audience: Annotated[ str, Field(default="5", description="JWT 受众"), @@ -349,11 +354,6 @@ STORAGE_SETTINGS='{ Field(default=30, description="设备信任持续天数"), "验证服务设置", ] - location_trust_duration_days: Annotated[ - int, - Field(default=90, description="位置信任持续天数"), - "验证服务设置", - ] smtp_server: Annotated[ str, Field(default="localhost", description="SMTP 服务器地址"), diff --git a/app/const.py b/app/const.py index 143e81c..439bc83 100644 --- a/app/const.py +++ b/app/const.py @@ -3,3 +3,5 @@ from __future__ import annotations BANCHOBOT_ID = 2 BACKUP_CODE_LENGTH = 10 + +SUPPORT_TOTP_VERIFICATION_VER = 20250913 diff --git a/app/database/__init__.py b/app/database/__init__.py index bac7cc1..328710a 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -68,7 +68,7 @@ from .user_account_history import ( UserAccountHistoryType, ) from .user_login_log import UserLoginLog -from .verification import EmailVerification, LoginSession +from .verification import EmailVerification, LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp __all__ = [ "APIUploadedRoom", @@ -96,6 +96,7 @@ __all__ = [ "ItemAttemptsCount", "ItemAttemptsResp", "LoginSession", + "LoginSessionResp", "MeResp", "MonthlyPlaycounts", "MultiplayerEvent", @@ -131,6 +132,8 @@ __all__ = [ "TeamMember", "TeamRequest", "TotpKeys", + "TrustedDevice", + "TrustedDeviceResp", "User", "UserAccountHistory", "UserAccountHistoryResp", diff --git a/app/database/auth.py b/app/database/auth.py index 3024d20..11995d8 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -32,7 +32,8 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True): refresh_token: str = Field(max_length=500, unique=True) token_type: str = Field(default="Bearer", max_length=20) scope: str = Field(default="*", max_length=100) - expires_at: datetime = Field(sa_column=Column(DateTime)) + expires_at: datetime = Field(sa_column=Column(DateTime, index=True)) + refresh_token_expires_at: datetime = Field(sa_column=Column(DateTime, index=True)) created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime)) user: "User" = Relationship() diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index d14b3e1..8d34778 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -243,7 +243,6 @@ class UserResp(UserBase): user_achievements: list[UserAchievementResp] = Field(default_factory=list) cover_url: str = "" # deprecated team: Team | None = None - session_verified: bool = True daily_challenge_user_stats: DailyChallengeStatsResp | None = None default_group: str = "" is_deleted: bool = False # TODO @@ -425,27 +424,18 @@ class UserResp(UserBase): ) ).one() - if "session_verified" in include: - from app.service.verification_service import LoginSessionService - - u.session_verified = ( - not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) - if token_id - else True - ) - return u class MeResp(UserResp): session_verification_method: Literal["totp", "mail"] | None = None + session_verified: bool = True @classmethod async def from_db( cls, obj: User, session: AsyncSession, - include: list[str] = [], ruleset: GameMode | None = None, *, token_id: int | None = None, @@ -453,7 +443,12 @@ class MeResp(UserResp): from app.dependencies.database import get_redis from app.service.verification_service import LoginSessionService - u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id) + u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id) + u.session_verified = ( + not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id) + if token_id + else True + ) u = cls.model_validate(u.model_dump()) if (settings.enable_totp_verification or settings.enable_email_verification) and token_id: redis = get_redis() diff --git a/app/database/verification.py b/app/database/verification.py index f5b9e32..f7e42b3 100644 --- a/app/database/verification.py +++ b/app/database/verification.py @@ -3,17 +3,26 @@ """ from datetime import datetime -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional -from app.utils import utcnow +from app.helpers.geoip_helper import GeoIPHelper +from app.models.model import UserAgentInfo, UTCBaseModel +from app.utils import extract_user_agent, utcnow +from pydantic import BaseModel from sqlalchemy import BigInteger, Column, ForeignKey -from sqlmodel import Field, Integer, Relationship, SQLModel +from sqlmodel import VARCHAR, DateTime, Field, Integer, Relationship, SQLModel, Text if TYPE_CHECKING: from .auth import OAuthToken +class Location(BaseModel): + country: str = "" + city: str = "" + country_code: str = "" + + class EmailVerification(SQLModel, table=True): """邮件验证记录""" @@ -31,25 +40,90 @@ class EmailVerification(SQLModel, table=True): user_agent: str | None = Field(default=None) # 用户代理 -class LoginSession(SQLModel, table=True): +class LoginSessionBase(SQLModel): """登录会话记录""" - __tablename__: str = "login_sessions" - - id: int | None = Field(default=None, primary_key=True) + id: int = Field(default=None, primary_key=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) - token_id: int | None = Field( - sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True) - ) - ip_address: str = Field() # 登录IP - user_agent: str | None = Field(default=None, max_length=250) - country_code: str | None = Field(default=None) + ip_address: str = Field(sa_column=Column(VARCHAR(45), nullable=False), default="127.0.0.1", exclude=True) + user_agent: str | None = Field(default=None, sa_column=Column(Text)) is_verified: bool = Field(default=False) # 是否已验证 created_at: datetime = Field(default_factory=lambda: utcnow()) verified_at: datetime | None = Field(default=None) expires_at: datetime = Field() # 会话过期时间 - is_new_location: bool = Field(default=False) # 是否新位置登录 - session_token: str | None = Field(default=None, max_length=64, index=True) # 会话令牌 - verification_method: str | None = Field(default=None, max_length=20) # 验证方法 (totp/mail) + device_id: int | None = Field( + sa_column=Column(BigInteger, ForeignKey("trusted_devices.id", ondelete="SET NULL"), nullable=True, index=True), + default=None, + ) + +class LoginSession(LoginSessionBase, table=True): + __tablename__: str = "login_sessions" + token_id: int | None = Field( + sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True), + exclude=True, + ) + is_new_device: bool = Field(default=False, exclude=True) # 是否新位置登录 + web_uuid: str | None = Field(sa_column=Column(VARCHAR(36), nullable=True), default=None, exclude=True) + verification_method: str | None = Field(default=None, max_length=20, exclude=True) # 验证方法 (totp/mail) + + device: Optional["TrustedDevice"] = Relationship(back_populates="sessions") token: Optional["OAuthToken"] = Relationship(back_populates="login_session") + + +class LoginSessionResp(UTCBaseModel, LoginSessionBase): + user_agent_info: UserAgentInfo | None = None + location: Location | None = None + + @classmethod + def from_db(cls, obj: LoginSession, get_geoip_helper: GeoIPHelper) -> "LoginSessionResp": + session = cls.model_validate(obj.model_dump()) + session.user_agent_info = extract_user_agent(session.user_agent) + if obj.ip_address: + loc = get_geoip_helper.lookup(obj.ip_address) + session.location = Location( + country=loc.get("country_name", ""), + city=loc.get("city_name", ""), + country_code=loc.get("country_code", ""), + ) + else: + session.location = None + return session + + +class TrustedDeviceBase(SQLModel): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) + ip_address: str = Field(sa_column=Column(VARCHAR(45), nullable=False), default="127.0.0.1", exclude=True) + user_agent: str = Field(sa_column=Column(Text, nullable=False)) + client_type: Literal["web", "client"] = Field(sa_column=Column(VARCHAR(10), nullable=False), default="web") + created_at: datetime = Field(default_factory=utcnow) + last_used_at: datetime = Field(default_factory=utcnow) + expires_at: datetime = Field(sa_column=Column(DateTime)) + + +class TrustedDevice(TrustedDeviceBase, table=True): + __tablename__: str = "trusted_devices" + web_uuid: str | None = Field(sa_column=Column(VARCHAR(36), nullable=True), default=None) + + sessions: list["LoginSession"] = Relationship(back_populates="device", passive_deletes=True) + + +class TrustedDeviceResp(UTCBaseModel, TrustedDeviceBase): + user_agent_info: UserAgentInfo | None = None + location: Location | None = None + + @classmethod + def from_db(cls, device: TrustedDevice, get_geoip_helper: GeoIPHelper) -> "TrustedDeviceResp": + device_ = cls.model_validate(device.model_dump()) + device_.user_agent_info = extract_user_agent(device_.user_agent) + if device_.ip_address: + loc = get_geoip_helper.lookup(device_.ip_address) + device_.location = Location( + country=loc.get("country_name", ""), + city=loc.get("city_name", ""), + country_code=loc.get("country_code", ""), + ) + else: + device_.location = None + return device_ diff --git a/app/dependencies/user_agent.py b/app/dependencies/user_agent.py new file mode 100644 index 0000000..6f776b1 --- /dev/null +++ b/app/dependencies/user_agent.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import Annotated + +from app.models.model import UserAgentInfo as UserAgentInfoModel +from app.utils import extract_user_agent + +from fastapi import Depends, Header + + +def get_user_agent_info(user_agent: str | None = Header(None, include_in_schema=False)) -> UserAgentInfoModel: + return extract_user_agent(user_agent) + + +UserAgentInfo = Annotated[UserAgentInfoModel, Depends(get_user_agent_info)] diff --git a/app/middleware/session_verification.py b/app/middleware/session_verification.py deleted file mode 100644 index eceb931..0000000 --- a/app/middleware/session_verification.py +++ /dev/null @@ -1,301 +0,0 @@ -""" -会话验证中间件和状态管理 - -基于osu-web的会话验证系统实现 -""" - -from __future__ import annotations - -from collections.abc import Awaitable, Callable -from typing import ClassVar, Literal, cast - -from app.database.lazer_user import User -from app.database.verification import LoginSession -from app.dependencies.database import get_redis, with_db -from app.log import logger -from app.service.verification_service import LoginSessionService -from app.utils import bg_tasks - -from fastapi import HTTPException, Request, Response, status -from fastapi.responses import JSONResponse -from redis.asyncio import Redis -from sqlmodel.ext.asyncio.session import AsyncSession - - -class SessionVerificationState: - """会话验证状态管理类 - - 参考osu-web的State类实现 - """ - - def __init__(self, session: LoginSession, user: User, redis: Redis): - self.session = session - self.user = user - self.redis = redis - - @classmethod - async def get_current( - cls, - request: Request, - db: AsyncSession, - redis: Redis, - user: User, - ) -> SessionVerificationState | None: - """获取当前会话验证状态""" - try: - # 从请求头或token中获取会话信息 - session_token = cls._extract_session_token(request) - if not session_token: - return None - - # 查找会话 - session = await LoginSessionService.find_for_verification(db, session_token) - if not session or session.user_id != user.id: - return None - - return cls(session, user, redis) - except Exception as e: - logger.error(f"[Session Verification] Error getting current state: {e}") - return None - - @staticmethod - def _extract_session_token(request: Request) -> str | None: - """从请求中提取会话token""" - # 尝试从Authorization header提取 - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - return auth_header[7:] # 移除"Bearer "前缀 - - # 可以扩展其他提取方式 - return None - - def get_method(self) -> str: - """获取验证方法 - - 参考osu-web的逻辑,智能选择验证方法 - """ - current_method = self.session.verification_method - - if current_method is None: - # 智能选择验证方法 - # 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证 - # 这里简化为检查用户是否有TOTP - totp_key = getattr(self.user, "totp_key", None) - current_method = "totp" if totp_key else "mail" - - # 设置验证方法 - bg_tasks.add_task(self._set_verification_method, current_method) - - return current_method - - async def _set_verification_method(self, method: str) -> None: - """内部方法:设置验证方法""" - try: - token_id = self.session.token_id - if token_id is not None and method in ["totp", "mail"]: - # 类型检查确保method是正确的字面量类型 - verification_method = method if method in ["totp", "mail"] else "totp" - await LoginSessionService.set_login_method( - self.user.id, - token_id, - cast(Literal["totp", "mail"], verification_method), - self.redis, - ) - except Exception as e: - logger.error(f"[Session Verification] Error setting verification method: {e}") - - def is_verified(self) -> bool: - """检查会话是否已验证""" - return self.session.is_verified - - async def mark_verified(self) -> None: - """标记会话为已验证""" - try: - async with with_db() as db: - # 创建专用数据库会话 - token_id = self.session.token_id - if token_id is not None: - await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id) - except Exception as e: - logger.error(f"[Session Verification] Error marking session verified: {e}") - - def get_key(self) -> str: - """获取会话密钥""" - return str(self.session.id) if self.session.id else "" - - def get_key_for_event(self) -> str: - """获取用于事件广播的会话密钥""" - return LoginSessionService.get_key_for_event(self.get_key()) - - def user_id(self) -> int: - """获取用户ID""" - return self.user.id - - async def issue_mail_if_needed(self) -> None: - """如果需要,发送验证邮件""" - try: - if self.get_method() == "mail": - from app.service.verification_service import EmailVerificationService - - # 创建专用数据库会话发送邮件 - async with with_db() as db: - await EmailVerificationService.send_verification_email( - db, self.redis, self.user.id, self.user.username, self.user.email, None, None - ) - except Exception as e: - logger.error(f"[Session Verification] Error issuing mail: {e}") - - -class SessionVerificationController: - """会话验证控制器 - - 参考osu-web的Controller类实现 - """ - - # 需要跳过验证的路由(参考osu-web的SKIP_VERIFICATION_ROUTES) - SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = { - "/api/v2/session/verify", - "/api/v2/session/verify/reissue", - "/api/v2/me", - "/api/v2/logout", - "/oauth/token", - } - - @staticmethod - def should_skip_verification(request: Request) -> bool: - """检查是否应该跳过验证""" - path = request.url.path - return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES - - @staticmethod - async def initiate_verification( - state: SessionVerificationState, - request: Request, - ) -> Response: - """启动会话验证流程 - - 参考osu-web的initiate方法 - """ - try: - method = state.get_method() - - # 如果是邮件验证,发送验证邮件 - if method == "mail": - await state.issue_mail_if_needed() - - # API请求返回JSON响应 - if request.url.path.startswith("/api/"): - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method}) - - # 其他情况可以扩展支持HTML响应 - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"authentication": "verify", "method": method, "message": "Session verification required"}, - ) - - except Exception as e: - logger.error(f"[Session Verification] Error initiating verification: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed" - ) - - -class SessionVerificationMiddleware: - """会话验证中间件 - - 参考osu-web的VerifyUser中间件实现 - """ - - def __init__(self, app: Callable[[Request], Awaitable[Response]]): - self.app = app - - async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: - """中间件主要逻辑""" - try: - # 检查是否需要跳过验证 - if SessionVerificationController.should_skip_verification(request): - return await call_next(request) - - # 获取依赖项 - user = await self._get_user(request) - if not user: - # 未认证用户跳过验证 - return await call_next(request) - - # 获取数据库和Redis连接 - async with with_db() as db: - redis = await self._get_redis() - - # 获取会话验证状态 - state = await SessionVerificationState.get_current(request, db, redis, user) - if not state: - # 无法获取会话状态,继续请求 - return await call_next(request) - - # 检查是否已验证 - if state.is_verified(): - # 已验证,继续请求 - return await call_next(request) - - # 检查是否需要验证 - if not self._requires_verification(request): - return await call_next(request) - - # 启动验证流程 - return await SessionVerificationController.initiate_verification(state, request) - - except Exception as e: - logger.error(f"[Session Verification Middleware] Unexpected error: {e}") - # 出错时允许请求继续,避免阻塞正常流程 - return await call_next(request) - - async def _get_user(self, request: Request) -> User | None: - """获取当前用户""" - try: - # 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入 - # 简化实现,实际应该从token中解析用户 - return None # 暂时返回None,需要实际实现 - except Exception: - return None - - async def _get_redis(self) -> Redis: - """获取Redis连接""" - return get_redis() - - def _requires_verification(self, request: Request) -> bool: - """检查是否需要验证 - - 参考osu-web的requiresVerification方法 - """ - method = request.method - - # GET/HEAD/OPTIONS请求一般不需要验证 - safe_methods = {"GET", "HEAD", "OPTIONS"} - if method in safe_methods: - return False - - # POST/PUT/DELETE等修改操作需要验证 - return True - - -# FastAPI中间件包装器 -class FastAPISessionVerificationMiddleware: - """FastAPI会话验证中间件包装器""" - - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - if scope["type"] != "http": - return await self.app(scope, receive, send) - - request = Request(scope, receive) - - async def call_next(req: Request) -> Response: - # 这里需要调用FastAPI应用 - return Response("OK") # 占位符实现 - - middleware = SessionVerificationMiddleware(call_next) - response = await middleware(request, call_next) - - await response(scope, receive, send) diff --git a/app/middleware/verify_session.py b/app/middleware/verify_session.py index fbc7329..2124ab5 100644 --- a/app/middleware/verify_session.py +++ b/app/middleware/verify_session.py @@ -10,11 +10,13 @@ from collections.abc import Callable from typing import ClassVar from app.auth import get_token_by_access_token +from app.const import SUPPORT_TOTP_VERIFICATION_VER from app.database.lazer_user import User from app.database.verification import LoginSession from app.dependencies.database import get_redis, with_db from app.log import logger from app.service.verification_service import LoginSessionService +from app.utils import extract_user_agent from fastapi import Request, Response, status from fastapi.responses import JSONResponse @@ -34,7 +36,9 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = { "/api/v2/session/verify", "/api/v2/session/verify/reissue", + "/api/v2/session/verify/mail-fallback", "/api/v2/me", + "/api/v2/me/", "/api/v2/logout", "/oauth/token", "/health", @@ -44,10 +48,8 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): "/redoc", } - # 需要强制验证的路由模式(敏感操作) + # 总是需要验证的路由前缀 ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = { - "/api/v2/account/", - "/api/v2/settings/", "/api/private/admin/", } @@ -110,9 +112,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): if path.startswith(pattern): return True - # 特权用户或非活跃用户需要验证 - # if hasattr(user, 'is_privileged') and user.is_privileged(): - # return True if not user.is_active: return True @@ -154,6 +153,14 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): try: # 提取会话token(这里简化为使用相同的auth token) auth_header = request.headers.get("Authorization", "") + api_version = 0 + raw_api_version = request.headers.get("x-api-version") + if raw_api_version is not None: + try: + api_version = int(raw_api_version) + except ValueError: + api_version = 0 + if not auth_header.startswith("Bearer "): return None @@ -168,7 +175,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): if not session or session.user_id != user.id: return None - return SessionState(session, user, redis, db) + return SessionState(session, user, redis, db, api_version) except Exception as e: logger.error(f"[Verify Session Middleware] Error getting session state: {e}") @@ -178,8 +185,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware): """启动验证流程""" try: method = await state.get_method() - - # 如果是邮件验证,可以在这里触发发送邮件 if method == "mail": await state.issue_mail_if_needed() @@ -202,11 +207,12 @@ class SessionState: 简化版本的会话状态管理 """ - def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession): + def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession, api_version: int = 0) -> None: self.session = session self.user = user self.redis = redis self.db = db + self.api_version = api_version self._verification_method: str | None = None def is_verified(self) -> bool: @@ -223,14 +229,15 @@ class SessionState: self.user.id, token_id, self.redis ) - # 如果没有设置,智能选择 if self._verification_method is None: - # 检查用户是否有TOTP密钥 - await self.user.awaitable_attrs.totp_key # 预加载 - totp_key = getattr(self.user, "totp_key", None) + if self.api_version < SUPPORT_TOTP_VERIFICATION_VER: + self._verification_method = "mail" + return self._verification_method + + await self.user.awaitable_attrs.totp_key + totp_key = self.user.totp_key self._verification_method = "totp" if totp_key else "mail" - # 保存选择的方法 token_id = self.session.token_id if token_id is not None: await LoginSessionService.set_login_method( @@ -244,8 +251,15 @@ class SessionState: try: token_id = self.session.token_id if token_id is not None: - await LoginSessionService.mark_session_verified(self.db, self.redis, self.user.id, token_id) - self.session.is_verified = True # 更新本地状态 + await LoginSessionService.mark_session_verified( + self.db, + self.redis, + self.user.id, + token_id, + self.session.ip_address, + extract_user_agent(self.session.user_agent), + self.session.web_uuid, + ) except Exception as e: logger.error(f"[Session State] Error marking verified: {e}") @@ -266,10 +280,12 @@ class SessionState: """获取会话密钥""" return str(self.session.id) if self.session.id else "" - def get_key_for_event(self) -> str: + @property + def key_for_event(self) -> str: """获取用于事件广播的会话密钥""" return LoginSessionService.get_key_for_event(self.get_key()) + @property def user_id(self) -> int: """获取用户ID""" return self.user.id diff --git a/app/models/model.py b/app/models/model.py index fbcd52e..3224c99 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from datetime import UTC, datetime from app.models.score import GameMode @@ -53,3 +54,33 @@ class CurrentUserAttributes(BaseModel): can_new_comment: bool | None = None can_new_comment_reason: str | None = None pin: PinAttributes | None = None + + +@dataclass +class UserAgentInfo: + raw_ua: str = "" + browser: str | None = None + version: str | None = None + os: str | None = None + platform: str | None = None + is_mobile: bool = False + is_tablet: bool = False + is_pc: bool = False + is_client: bool = False + + @property + def displayed_name(self) -> str: + parts = [] + if self.browser: + parts.append(self.browser) + if self.version: + parts.append(self.version) + if self.os: + if parts: + parts.append(f"on {self.os}") + else: + parts.append(self.os) + return " ".join(parts) if parts else "Unknown" + + def __str__(self) -> str: + return self.displayed_name diff --git a/app/router/auth.py b/app/router/auth.py index 7254b2c..544ce50 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -21,6 +21,7 @@ from app.database.auth import TotpKeys from app.database.statistics import UserStatistics from app.dependencies.database import Database, get_redis from app.dependencies.geoip import get_client_ip, get_geoip_helper +from app.dependencies.user_agent import UserAgentInfo from app.helpers.geoip_helper import GeoIPHelper from app.log import logger from app.models.extended_auth import ExtendedTokenResponse @@ -39,7 +40,7 @@ from app.service.verification_service import ( ) from app.utils import utcnow -from fastapi import APIRouter, Depends, Form, Request +from fastapi import APIRouter, Depends, Form, Header, Request from fastapi.responses import JSONResponse from redis.asyncio import Redis from sqlalchemy import text @@ -199,6 +200,7 @@ async def register_user( async def oauth_token( db: Database, request: Request, + user_agent: UserAgentInfo, grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form( ..., description="授权类型:密码/刷新令牌/授权码/客户端凭证" ), @@ -211,12 +213,10 @@ async def oauth_token( refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"), redis: Redis = Depends(get_redis), geoip: GeoIPHelper = Depends(get_geoip_helper), + web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"), ): scopes = scope.split(" ") - # 打印请求头 - # logger.info(f"Request headers: {request.headers}") - client = ( await db.exec( select(OAuthClient).where( @@ -306,19 +306,19 @@ async def oauth_token( access_token, refresh_token_str, settings.access_token_expire_minutes * 60, + settings.refresh_token_expire_minutes * 60, allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) token_id = token.id ip_address = get_client_ip(request) - user_agent = request.headers.get("User-Agent", "") # 获取国家代码 geo_info = geoip.lookup(ip_address) country_code = geo_info.get("country_iso", "XX") # 检查是否为新位置登录 - is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code) + trusted_device = await LoginSessionService.check_trusted_device(db, user_id, ip_address, user_agent, web_uuid) session_verification_method = None if settings.enable_totp_verification and totp_key is not None: @@ -331,18 +331,12 @@ async def oauth_token( login_method="password_pending_verification", notes="需要 TOTP 验证", ) - elif is_new_location and settings.enable_email_verification: - # 如果是新位置登录,需要邮件验证 + elif not trusted_device and settings.enable_email_verification: + # 如果是新设备登录,需要邮件验证 # 刷新用户对象以确保属性已加载 await db.refresh(user) session_verification_method = "mail" - - # 使用智能验证发送邮件 - ( - verification_sent, - verification_message, - client_info, - ) = await EmailVerificationService.send_smart_verification_email( + await EmailVerificationService.send_verification_email( db, redis, user_id, @@ -350,36 +344,30 @@ async def oauth_token( user.email, ip_address, user_agent, - client_id, - country_code, - is_new_location, ) # 记录需要二次验证的登录尝试 - client_display_name = client_info.client_type if client_info else "unknown" await LoginLogService.record_login( db=db, user_id=user_id, request=request, login_success=True, login_method="password_pending_verification", - notes=f"智能验证: {verification_message} - 客户端: {client_display_name}, " - f"IP: {ip_address}, 国家: {country_code}", + notes=( + f"邮箱验证: User-Agent: {user_agent.raw_ua}, 客户端: {user_agent.displayed_name} " + f"IP: {ip_address}, 国家: {country_code}" + ), + ) + elif not trusted_device: + # 新设备登录但邮件验证功能被禁用,直接标记会话为已验证 + await LoginSessionService.mark_session_verified( + db, redis, user_id, token_id, ip_address, user_agent, web_uuid ) - - if not verification_sent: - # 邮件发送失败,记录错误 - logger.error(f"[Auth] Smart verification failed for user {user_id}: {verification_message}") - else: - logger.info(f"[Auth] Smart verification result for user {user_id}: {verification_message}") - elif is_new_location: - # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 - await LoginSessionService.mark_session_verified(db, redis, user_id, token_id) logger.debug( f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}" ) else: - # 不是新位置登录,正常登录 + # 不是新设备登录,正常登录 await LoginLogService.record_login( db=db, user_id=user_id, @@ -391,12 +379,12 @@ async def oauth_token( if session_verification_method: await LoginSessionService.create_session( - db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, False + db, user_id, token_id, ip_address, user_agent.raw_ua, trusted_device, web_uuid, False ) await LoginSessionService.set_login_method(user_id, token_id, session_verification_method, redis) else: await LoginSessionService.create_session( - db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, True + db, user_id, token_id, ip_address, user_agent.raw_ua, trusted_device, web_uuid, True ) return TokenResponse( @@ -449,6 +437,7 @@ async def oauth_token( access_token, new_refresh_token, settings.access_token_expire_minutes * 60, + settings.refresh_token_expire_minutes * 60, allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) return TokenResponse( @@ -514,6 +503,7 @@ async def oauth_token( access_token, refresh_token_str, settings.access_token_expire_minutes * 60, + settings.refresh_token_expire_minutes * 60, allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) @@ -561,6 +551,7 @@ async def oauth_token( access_token, refresh_token_str, settings.access_token_expire_minutes * 60, + settings.refresh_token_expire_minutes * 60, allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持 ) diff --git a/app/router/private/__init__.py b/app/router/private/__init__.py index f3a9590..f1715a1 100644 --- a/app/router/private/__init__.py +++ b/app/router/private/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations from app.config import settings -from . import audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401 +from . import admin, audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401 from .router import router as private_router if settings.enable_totp_verification: diff --git a/app/router/private/admin.py b/app/router/private/admin.py new file mode 100644 index 0000000..e29a264 --- /dev/null +++ b/app/router/private/admin.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from app.database.auth import OAuthToken +from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp +from app.dependencies.database import Database +from app.dependencies.geoip import get_geoip_helper +from app.dependencies.user import UserAndToken, get_client_user_and_token +from app.helpers.geoip_helper import GeoIPHelper + +from .router import router + +from fastapi import Depends, HTTPException, Security +from pydantic import BaseModel +from sqlmodel import col, select + + +class SessionsResp(BaseModel): + total: int + current: int = 0 + sessions: list[LoginSessionResp] + + +@router.get( + "/admin/sessions", + name="获取当前用户的登录会话列表", + tags=["用户会话", "g0v0 API", "管理"], + response_model=SessionsResp, +) +async def get_sessions( + session: Database, + user_and_token: UserAndToken = Security(get_client_user_and_token), + geoip: GeoIPHelper = Depends(get_geoip_helper), +): + current_user, token = user_and_token + sessions = ( + await session.exec( + select( + LoginSession, + ) + .where(LoginSession.user_id == current_user.id, col(LoginSession.is_verified).is_(True)) + .order_by(col(LoginSession.created_at).desc()) + ) + ).all() + return SessionsResp( + total=len(sessions), + current=token.id, + sessions=[LoginSessionResp.from_db(s, geoip) for s in sessions], + ) + + +@router.delete( + "/admin/sessions/{session_id}", + name="注销指定的登录会话", + tags=["用户会话", "g0v0 API", "管理"], + status_code=204, +) +async def delete_session( + session: Database, + session_id: int, + user_and_token: UserAndToken = Security(get_client_user_and_token), +): + current_user, token = user_and_token + if session_id == token.id: + raise HTTPException(status_code=400, detail="Cannot delete the current session") + + db_session = await session.get(LoginSession, session_id) + if not db_session or db_session.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Session not found") + + await session.delete(db_session) + + token = await session.get(OAuthToken, db_session.token_id or 0) + if token: + await session.delete(token) + + await session.commit() + return + + +class TrustedDevicesResp(BaseModel): + total: int + current: int = 0 + devices: list[TrustedDeviceResp] + + +@router.get( + "/admin/trusted-devices", + name="获取当前用户的受信任设备列表", + tags=["用户会话", "g0v0 API", "管理"], + response_model=TrustedDevicesResp, +) +async def get_trusted_devices( + session: Database, + user_and_token: UserAndToken = Security(get_client_user_and_token), + geoip: GeoIPHelper = Depends(get_geoip_helper), +): + current_user, token = user_and_token + devices = ( + await session.exec( + select(TrustedDevice) + .where(TrustedDevice.user_id == current_user.id) + .order_by(col(TrustedDevice.last_used_at).desc()) + ) + ).all() + + current_device_id = ( + await session.exec( + select(TrustedDevice.id) + .join(LoginSession, col(LoginSession.device_id) == TrustedDevice.id) + .where( + LoginSession.token_id == token.id, + TrustedDevice.user_id == current_user.id, + ) + .limit(1) + ) + ).first() + + return TrustedDevicesResp( + total=len(devices), + current=current_device_id or 0, + devices=[TrustedDeviceResp.from_db(device, geoip) for device in devices], + ) + + +@router.delete( + "/admin/trusted-devices/{device_id}", + name="移除受信任设备", + tags=["用户会话", "g0v0 API", "管理"], + status_code=204, +) +async def delete_trusted_device( + session: Database, + device_id: int, + user_and_token: UserAndToken = Security(get_client_user_and_token), +): + current_user, token = user_and_token + device = await session.get(TrustedDevice, device_id) + current_device_id = ( + await session.exec( + select(TrustedDevice.id) + .join(LoginSession, col(LoginSession.device_id) == TrustedDevice.id) + .where( + LoginSession.token_id == token.id, + TrustedDevice.user_id == current_user.id, + ) + .limit(1) + ) + ).first() + if device_id == current_device_id: + raise HTTPException(status_code=400, detail="Cannot delete the current trusted device") + + if not device or device.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Trusted device not found") + + await session.delete(device) + await session.commit() + return diff --git a/app/router/v2/me.py b/app/router/v2/me.py index cab441d..fe1e797 100644 --- a/app/router/v2/me.py +++ b/app/router/v2/me.py @@ -1,7 +1,6 @@ from __future__ import annotations from app.database import MeResp, User -from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user from app.dependencies.database import Database from app.dependencies.user import UserAndToken, get_current_user_and_token @@ -33,7 +32,7 @@ async def get_user_info_with_ruleset( ruleset: GameMode = Path(description="指定 ruleset"), user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), ): - user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, ruleset, token_id=user_and_token[1].id) + user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id) return user_resp @@ -48,7 +47,7 @@ async def get_user_info_default( session: Database, user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]), ): - user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, None, token_id=user_and_token[1].id) + user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id) return user_resp diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index c71cabc..81abb5d 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -8,12 +8,13 @@ from typing import Annotated, Literal from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection from app.config import settings -from app.const import BACKUP_CODE_LENGTH +from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER from app.database.auth import TotpKeys from app.dependencies.api_version import APIVersion from app.dependencies.database import Database, get_redis from app.dependencies.geoip import get_client_ip from app.dependencies.user import UserAndToken, get_client_user_and_token +from app.dependencies.user_agent import UserAgentInfo from app.log import logger from app.service.login_log_service import LoginLogService from app.service.verification_service import ( @@ -23,7 +24,7 @@ from app.service.verification_service import ( from .router import router -from fastapi import Depends, Form, HTTPException, Request, Security, status +from fastapi import Depends, Form, Header, HTTPException, Request, Security, status from fastapi.responses import JSONResponse, Response from pydantic import BaseModel from redis.asyncio import Redis @@ -62,9 +63,11 @@ async def verify_session( request: Request, db: Database, api_version: APIVersion, + user_agent: UserAgentInfo, redis: Annotated[Redis, Depends(get_redis)], verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"), user_and_token: UserAndToken = Security(get_client_user_and_token), + web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"), ) -> Response: current_user = user_and_token[0] token_id = user_and_token[1].id @@ -74,11 +77,12 @@ async def verify_session( return Response(status_code=status.HTTP_204_NO_CONTENT) verify_method: str | None = ( - "mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis) + "mail" + if api_version < SUPPORT_TOTP_VERIFICATION_VER + else await LoginSessionService.get_login_method(user_id, token_id, redis) ) ip_address = get_client_ip(request) - user_agent = request.headers.get("User-Agent", "Unknown") login_method = "password" try: @@ -130,10 +134,11 @@ async def verify_session( user_id=user_id, request=request, login_method=login_method, + user_agent=user_agent.raw_ua, login_success=True, notes=f"{login_method} 验证成功", ) - await LoginSessionService.mark_session_verified(db, redis, user_id, token_id) + await LoginSessionService.mark_session_verified(db, redis, user_id, token_id, ip_address, user_agent, web_uuid) await db.commit() return Response(status_code=status.HTTP_204_NO_CONTENT) @@ -179,6 +184,7 @@ async def verify_session( async def reissue_verification_code( request: Request, db: Database, + user_agent: UserAgentInfo, api_version: APIVersion, redis: Annotated[Redis, Depends(get_redis)], user_and_token: UserAndToken = Security(get_client_user_and_token), @@ -198,7 +204,6 @@ async def reissue_verification_code( try: ip_address = get_client_ip(request) - user_agent = request.headers.get("User-Agent", "Unknown") user_id = current_user.id success, message = await EmailVerificationService.resend_verification_code( db, @@ -227,6 +232,7 @@ async def reissue_verification_code( ) async def fallback_email( db: Database, + user_agent: UserAgentInfo, request: Request, redis: Annotated[Redis, Depends(get_redis)], user_and_token: UserAndToken = Security(get_client_user_and_token), @@ -237,7 +243,6 @@ async def fallback_email( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退") ip_address = get_client_ip(request) - user_agent = request.headers.get("User-Agent", "Unknown") await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis) success, message = await EmailVerificationService.resend_verification_code( diff --git a/app/scheduler/database_cleanup_scheduler.py b/app/scheduler/database_cleanup_scheduler.py deleted file mode 100644 index afee738..0000000 --- a/app/scheduler/database_cleanup_scheduler.py +++ /dev/null @@ -1,122 +0,0 @@ -""" -数据库清理调度器 - 定时清理过期数据 -""" - -from __future__ import annotations - -import asyncio - -from app.dependencies.database import engine -from app.log import logger -from app.service.database_cleanup_service import DatabaseCleanupService - -from sqlmodel.ext.asyncio.session import AsyncSession - - -class DatabaseCleanupScheduler: - """数据库清理调度器""" - - def __init__(self): - self.running = False - self.task = None - - async def start(self): - """启动调度器""" - if self.running: - return - - self.running = True - self.task = asyncio.create_task(self._run_scheduler()) - logger.debug("Database cleanup scheduler started") - - async def stop(self): - """停止调度器""" - if not self.running: - return - - self.running = False - if self.task: - self.task.cancel() - try: - await self.task - except asyncio.CancelledError: - pass - logger.debug("Database cleanup scheduler stopped") - - async def _run_scheduler(self): - """运行调度器""" - while self.running: - try: - # 每小时运行一次清理 - await asyncio.sleep(3600) # 3600秒 = 1小时 - - if not self.running: - break - - await self._run_cleanup() - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Database cleanup scheduler error: {e!s}") - # 发生错误后等待5分钟再继续 - await asyncio.sleep(300) - - async def _run_cleanup(self): - """执行清理任务""" - try: - async with AsyncSession(engine) as db: - logger.debug("Starting scheduled database cleanup...") - - # 清理过期的验证码 - expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db) - - # 清理过期的登录会话 - expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db) - - # 清理1小时前未验证的登录会话 - unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1) - - # 只在有清理记录时输出总结 - total_cleaned = expired_codes + expired_sessions + unverified_sessions - if total_cleaned > 0: - logger.debug( - f"Scheduled cleanup completed - codes: {expired_codes}, " - f"sessions: {expired_sessions}, unverified: {unverified_sessions}" - ) - - except Exception as e: - logger.error(f"Error during scheduled database cleanup: {e!s}") - - async def run_manual_cleanup(self): - """手动运行完整清理""" - try: - async with AsyncSession(engine) as db: - logger.debug("Starting manual database cleanup...") - results = await DatabaseCleanupService.run_full_cleanup(db) - total = sum(results.values()) - if total > 0: - logger.debug(f"Manual cleanup completed, total records cleaned: {total}") - return results - except Exception as e: - logger.error(f"Error during manual database cleanup: {e!s}") - return {} - - -# 全局实例 -database_cleanup_scheduler = DatabaseCleanupScheduler() - - -async def start_database_cleanup_scheduler(): - """启动数据库清理调度器""" - await database_cleanup_scheduler.start() - - -async def stop_database_cleanup_scheduler(): - """停止数据库清理调度器""" - await database_cleanup_scheduler.stop() - - -async def run_manual_database_cleanup(): - """手动运行数据库清理""" - return await database_cleanup_scheduler.run_manual_cleanup() diff --git a/app/service/client_detection_service.py b/app/service/client_detection_service.py deleted file mode 100644 index 5fc98c4..0000000 --- a/app/service/client_detection_service.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -客户端检测服务 -用于识别不同类型的 osu! 客户端和设备 -""" - -from __future__ import annotations - -from dataclasses import dataclass -import hashlib -import re -from typing import ClassVar, Literal - -from app.log import logger - - -@dataclass -class ClientInfo: - """客户端信息""" - - client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] - platform: str | None = None - version: str | None = None - device_fingerprint: str | None = None - is_trusted_client: bool = False - - -class ClientDetectionService: - """客户端检测服务""" - - # osu! 客户端的 User-Agent 模式 - OSU_CLIENT_PATTERNS: ClassVar[dict[str, list[str]]] = { - "osu_stable": [ - r"osu!/(\d+(?:\.\d+)*)", # osu!/20241001 - r"osu!", # 简单匹配 - ], - "osu_lazer": [ - r"osu-lazer/(\d+(?:\.\d+)*)", # osu-lazer/2024.1009.0 - r"osu!lazer/(\d+(?:\.\d+)*)", # osu!lazer/2024.1009.0 - ], - "osu_web": [ - r"Mozilla.*osu\.ppy\.sh", # 网页客户端 - ], - "mobile": [ - r"osu!.*mobile", - r"osu.*Mobile", - r"Mobile.*osu", - ], - } - - # 受信任的客户端类型(不需要频繁验证) - TRUSTED_CLIENT_TYPES: ClassVar[set[str]] = {"osu_stable", "osu_lazer"} - - @staticmethod - def detect_client(user_agent: str | None, client_id: int | None = None) -> ClientInfo: - """ - 检测客户端类型和信息 - - Args: - user_agent: 用户代理字符串 - client_id: OAuth 客户端 ID - - Returns: - ClientInfo: 客户端信息 - """ - from app.config import settings # 导入在函数内部避免循环导入 - - if not user_agent: - return ClientInfo(client_type="unknown") - - # 优先通过 client_id 判断客户端类型 - if client_id is not None: - if client_id == settings.osu_client_id: - # osu! stable 客户端 - return ClientInfo( - client_type="osu_stable", - platform=ClientDetectionService._extract_platform(user_agent), - device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), - is_trusted_client=True, - ) - elif client_id == settings.osu_web_client_id: - # 检查 User-Agent 是否表明这是 Lazer 客户端 - if user_agent and user_agent.strip() == "osu!": - # Lazer 客户端使用 web client_id 但发送简单的 "osu!" User-Agent - return ClientInfo( - client_type="osu_lazer", - platform=ClientDetectionService._extract_platform(user_agent), - device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), - is_trusted_client=True, - ) - else: - # 真正的 web 客户端 - return ClientInfo( - client_type="osu_web", - platform=ClientDetectionService._extract_platform(user_agent), - device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), - is_trusted_client=False, - ) - - # 回退到基于 User-Agent 的检测 - for client_type_str, patterns in ClientDetectionService.OSU_CLIENT_PATTERNS.items(): - for pattern in patterns: - match = re.search(pattern, user_agent, re.IGNORECASE) - if match: - version = match.group(1) if match.groups() else None - platform = ClientDetectionService._extract_platform(user_agent) - - # 确保 client_type 是正确的 Literal 类型 - client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] = client_type_str # type: ignore - - return ClientInfo( - client_type=client_type, - platform=platform, - version=version, - device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), - is_trusted_client=client_type in ClientDetectionService.TRUSTED_CLIENT_TYPES, - ) - - # 检测常见浏览器 - if any(browser in user_agent.lower() for browser in ["chrome", "firefox", "safari", "edge"]): - return ClientInfo( - client_type="osu_web", - platform=ClientDetectionService._extract_platform(user_agent), - device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), - is_trusted_client=False, - ) - - return ClientInfo( - client_type="unknown", - device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent), - is_trusted_client=False, - ) - - @staticmethod - def _extract_platform(user_agent: str) -> str | None: - """从 User-Agent 中提取平台信息""" - platforms = { - "windows": ["windows", "win32", "win64"], - "macos": ["macintosh", "mac os", "darwin"], - "linux": ["linux", "ubuntu", "debian"], - "android": ["android"], - "ios": ["iphone", "ipad", "ios"], - } - - user_agent_lower = user_agent.lower() - for platform, keywords in platforms.items(): - if any(keyword in user_agent_lower for keyword in keywords): - return platform - - return None - - @staticmethod - def _generate_device_fingerprint(user_agent: str) -> str: - """生成设备指纹""" - # 使用 User-Agent 的哈希值作为简单的设备指纹 - # 在实际应用中可以结合更多信息(IP、屏幕分辨率等) - return hashlib.sha256(user_agent.encode()).hexdigest()[:16] - - @staticmethod - def should_skip_email_verification( - client_info: ClientInfo, - is_new_location: bool, - user_id: int, - ) -> bool: - """ - 判断是否应该跳过邮件验证 - - Args: - client_info: 客户端信息 - is_new_location: 是否为新位置登录 - user_id: 用户 ID - - Returns: - bool: 是否应该跳过邮件验证 - """ - # 受信任的客户端类型可以减少验证频率 - if client_info.is_trusted_client: - logger.info( - f"[Client Detection] Trusted client {client_info.client_type} for user {user_id}, " - f"reducing verification requirements" - ) - return True - - # 如果不是新位置,跳过验证 - if not is_new_location: - return True - - return False - - @staticmethod - def get_verification_cooldown(client_info: ClientInfo) -> int: - """ - 获取验证冷却时间(秒) - - Args: - client_info: 客户端信息 - - Returns: - int: 冷却时间(秒) - """ - # 受信任的客户端有更长的冷却时间 - if client_info.is_trusted_client: - return 3600 # 1小时 - - # 网页客户端较短的冷却时间 - if client_info.client_type == "osu_web": - return 1800 # 30分钟 - - # 未知客户端最短冷却时间 - return 900 # 15分钟 - - @staticmethod - def format_client_display_name(client_info: ClientInfo) -> str: - """格式化客户端显示名称""" - display_names = { - "osu_stable": "osu! (stable)", - "osu_lazer": "osu!(lazer)", - "osu_web": "osu! web", - "mobile": "osu! mobile", - "unknown": "Unknown client", - } - - base_name = display_names.get(client_info.client_type, "Unknown client") - - if client_info.version: - base_name += f" v{client_info.version}" - - if client_info.platform: - base_name += f" ({client_info.platform})" - - return base_name diff --git a/app/service/database_cleanup_service.py b/app/service/database_cleanup_service.py index 5d40db5..3acac4a 100644 --- a/app/service/database_cleanup_service.py +++ b/app/service/database_cleanup_service.py @@ -6,11 +6,14 @@ from __future__ import annotations from datetime import timedelta -from app.database.verification import EmailVerification, LoginSession +from app.database.auth import OAuthToken +from app.database.verification import EmailVerification, LoginSession, TrustedDevice +from app.dependencies.database import with_db +from app.dependencies.scheduler import get_scheduler from app.log import logger from app.utils import utcnow -from sqlmodel import col, select +from sqlmodel import col, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -69,7 +72,9 @@ class DatabaseCleanupService: # 查找过期的登录会话记录 current_time = utcnow() - stmt = select(LoginSession).where(LoginSession.expires_at < current_time) + stmt = select(LoginSession).where( + LoginSession.expires_at < current_time, col(LoginSession.is_verified).is_(False) + ) result = await db.exec(stmt) expired_sessions = result.all() @@ -179,50 +184,109 @@ class DatabaseCleanupService: return 0 @staticmethod - async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int: + async def cleanup_outdated_verified_sessions(db: AsyncSession) -> int: """ - 清理旧的已验证会话记录 + 清理过期会话记录 Args: db: 数据库会话 - days_old: 清理多少天前的已验证记录,默认30天 Returns: int: 清理的记录数 """ try: - # 查找指定天数前的已验证会话记录 - cutoff_time = utcnow() - timedelta(days=days_old) - - stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) + stmt = select(LoginSession).where( + col(LoginSession.is_verified).is_(True), col(LoginSession.token_id).is_(None) + ) result = await db.exec(stmt) - all_verified_sessions = result.all() - - # 筛选出过期的记录 - old_verified_sessions = [ - session - for session in all_verified_sessions - if session.verified_at and session.verified_at < cutoff_time - ] - # 删除旧的已验证记录 deleted_count = 0 - for session in old_verified_sessions: + for session in result.all(): await db.delete(session) deleted_count += 1 await db.commit() if deleted_count > 0: - logger.debug( - f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days" - ) + logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} outdated verified sessions") return deleted_count except Exception as e: await db.rollback() - logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}") + logger.error(f"[Cleanup Service] Error cleaning outdated verified sessions: {e!s}") + return 0 + + @staticmethod + async def cleanup_outdated_trusted_devices(db: AsyncSession) -> int: + """ + 清理过期的受信任设备记录 + + Args: + db: 数据库会话 + + Returns: + int: 清理的记录数 + """ + try: + # 查找过期的受信任设备记录 + current_time = utcnow() + + stmt = select(TrustedDevice).where(TrustedDevice.expires_at < current_time) + result = await db.exec(stmt) + expired_devices = result.all() + + # 删除过期的记录 + deleted_count = 0 + for device in expired_devices: + await db.delete(device) + deleted_count += 1 + + await db.commit() + + if deleted_count > 0: + logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired trusted devices") + + return deleted_count + + except Exception as e: + await db.rollback() + logger.error(f"[Cleanup Service] Error cleaning expired trusted devices: {e!s}") + return 0 + + @staticmethod + async def cleanup_outdated_tokens(db: AsyncSession) -> int: + """ + 清理过期的 OAuth 令牌 + + Args: + db: 数据库会话 + + Returns: + int: 清理的记录数 + """ + try: + current_time = utcnow() + + stmt = select(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time) + result = await db.exec(stmt) + expired_tokens = result.all() + + deleted_count = 0 + for token in expired_tokens: + await db.delete(token) + deleted_count += 1 + + await db.commit() + + if deleted_count > 0: + logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired OAuth tokens") + + return deleted_count + + except Exception as e: + await db.rollback() + logger.error(f"[Cleanup Service] Error cleaning expired OAuth tokens: {e!s}") return 0 @staticmethod @@ -250,8 +314,14 @@ class DatabaseCleanupService: # 清理7天前的已使用验证码 results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7) - # 清理30天前的已验证会话 - results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30) + # 清理过期的受信任设备 + results["outdated_trusted_devices"] = await DatabaseCleanupService.cleanup_outdated_trusted_devices(db) + + # 清理过期的 OAuth 令牌 + results["outdated_oauth_tokens"] = await DatabaseCleanupService.cleanup_outdated_tokens(db) + + # 清理过期(token 过期)的已验证会话 + results["outdated_verified_sessions"] = await DatabaseCleanupService.cleanup_outdated_verified_sessions(db) total_cleaned = sum(results.values()) if total_cleaned > 0: @@ -279,21 +349,27 @@ class DatabaseCleanupService: cutoff_30_days = current_time - timedelta(days=30) # 统计过期的验证码数量 - expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time) + expired_codes_stmt = ( + select(func.count()).select_from(EmailVerification).where(EmailVerification.expires_at < current_time) + ) expired_codes_result = await db.exec(expired_codes_stmt) - expired_codes_count = len(expired_codes_result.all()) + expired_codes_count = expired_codes_result.one() # 统计过期的登录会话数量 - expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time) + expired_sessions_stmt = ( + select(func.count()).select_from(LoginSession).where(LoginSession.expires_at < current_time) + ) expired_sessions_result = await db.exec(expired_sessions_stmt) - expired_sessions_count = len(expired_sessions_result.all()) + expired_sessions_count = expired_sessions_result.one() # 统计1小时前未验证的登录会话数量 - unverified_sessions_stmt = select(LoginSession).where( - col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour + unverified_sessions_stmt = ( + select(func.count()) + .select_from(LoginSession) + .where(col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour) ) unverified_sessions_result = await db.exec(unverified_sessions_stmt) - unverified_sessions_count = len(unverified_sessions_result.all()) + unverified_sessions_count = unverified_sessions_result.one() # 统计7天前的已使用验证码数量 old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) @@ -304,10 +380,10 @@ class DatabaseCleanupService: ) # 统计30天前的已验证会话数量 - old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) - old_verified_sessions_result = await db.exec(old_verified_sessions_stmt) - all_verified_sessions = old_verified_sessions_result.all() - old_verified_sessions_count = len( + outdated_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) + outdated_verified_sessions_result = await db.exec(outdated_verified_sessions_stmt) + all_verified_sessions = outdated_verified_sessions_result.all() + outdated_verified_sessions_count = len( [ session for session in all_verified_sessions @@ -315,17 +391,35 @@ class DatabaseCleanupService: ] ) + # 统计过期的 OAuth 令牌数量 + outdated_tokens_stmt = ( + select(func.count()).select_from(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time) + ) + outdated_tokens_result = await db.exec(outdated_tokens_stmt) + outdated_tokens_count = outdated_tokens_result.one() + + # 统计过期的受信任设备数量 + outdated_devices_stmt = ( + select(func.count()).select_from(TrustedDevice).where(TrustedDevice.expires_at < current_time) + ) + outdated_devices_result = await db.exec(outdated_devices_stmt) + outdated_devices_count = outdated_devices_result.one() + return { "expired_verification_codes": expired_codes_count, "expired_login_sessions": expired_sessions_count, "unverified_login_sessions": unverified_sessions_count, "old_used_verification_codes": old_used_codes_count, - "old_verified_sessions": old_verified_sessions_count, + "outdated_verified_sessions": outdated_verified_sessions_count, + "outdated_oauth_tokens": outdated_tokens_count, + "outdated_trusted_devices": outdated_devices_count, "total_cleanable": expired_codes_count + expired_sessions_count + unverified_sessions_count + old_used_codes_count - + old_verified_sessions_count, + + outdated_verified_sessions_count + + outdated_tokens_count + + outdated_devices_count, } except Exception as e: @@ -335,6 +429,23 @@ class DatabaseCleanupService: "expired_login_sessions": 0, "unverified_login_sessions": 0, "old_used_verification_codes": 0, - "old_verified_sessions": 0, + "outdated_verified_sessions": 0, + "outdated_oauth_tokens": 0, + "outdated_trusted_devices": 0, "total_cleanable": 0, } + + +@get_scheduler().scheduled_job( + "interval", + id="cleanup_database", + hours=1, +) +async def scheduled_cleanup_job(): + async with with_db() as session: + logger.debug("Starting database cleanup...") + results = await DatabaseCleanupService.run_full_cleanup(session) + total = sum(results.values()) + if total > 0: + logger.debug(f"Cleanup completed, total records cleaned: {total}") + return results diff --git a/app/service/device_trust_service.py b/app/service/device_trust_service.py deleted file mode 100644 index 1b4e623..0000000 --- a/app/service/device_trust_service.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -设备信任服务 -管理用户的受信任设备,减少频繁验证 -""" - -from __future__ import annotations - -from datetime import timedelta - -from app.config import settings -from app.log import logger -from app.service.client_detection_service import ClientInfo -from app.utils import utcnow - -from redis.asyncio import Redis - - -class DeviceTrustService: - """设备信任服务""" - - @staticmethod - def _get_device_trust_key(user_id: int, device_fingerprint: str) -> str: - """获取设备信任的 Redis 键""" - return f"device_trust:{user_id}:{device_fingerprint}" - - @staticmethod - def _get_location_trust_key(user_id: int, country_code: str) -> str: - """获取位置信任的 Redis 键""" - return f"location_trust:{user_id}:{country_code}" - - @staticmethod - def _get_verification_cooldown_key(user_id: int) -> str: - """获取验证冷却的 Redis 键""" - return f"verification_cooldown:{user_id}" - - @staticmethod - async def is_device_trusted( - redis: Redis, - user_id: int, - device_fingerprint: str, - ) -> bool: - """ - 检查设备是否受信任 - - Args: - redis: Redis 连接 - user_id: 用户 ID - device_fingerprint: 设备指纹 - - Returns: - bool: 设备是否受信任 - """ - if not device_fingerprint: - return False - - trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint) - trust_data = await redis.get(trust_key) - - return trust_data is not None - - @staticmethod - async def is_location_trusted( - redis: Redis, - user_id: int, - country_code: str | None, - ) -> bool: - """ - 检查位置是否受信任 - - Args: - redis: Redis 连接 - user_id: 用户 ID - country_code: 国家代码 - - Returns: - bool: 位置是否受信任 - """ - if not country_code: - return False - - trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code) - trust_data = await redis.get(trust_key) - - return trust_data is not None - - @staticmethod - async def is_in_verification_cooldown( - redis: Redis, - user_id: int, - ) -> bool: - """ - 检查用户是否在验证冷却期内 - - Args: - redis: Redis 连接 - user_id: 用户 ID - - Returns: - bool: 是否在冷却期内 - """ - cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id) - cooldown_data = await redis.get(cooldown_key) - - return cooldown_data is not None - - @staticmethod - async def trust_device( - redis: Redis, - user_id: int, - device_fingerprint: str, - client_info: ClientInfo, - trust_duration_days: int | None = None, - ) -> None: - """ - 信任设备 - - Args: - redis: Redis 连接 - user_id: 用户 ID - device_fingerprint: 设备指纹 - client_info: 客户端信息 - trust_duration_days: 信任持续天数 - """ - if not device_fingerprint: - return - - # 使用配置中的默认值 - if trust_duration_days is None: - trust_duration_days = settings.device_trust_duration_days - - trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint) - trust_data = { - "client_type": client_info.client_type, - "platform": client_info.platform or "unknown", - "trusted_at": utcnow().isoformat(), - } - - # 设置信任期限 - trust_duration_seconds = trust_duration_days * 24 * 3600 - await redis.setex(trust_key, trust_duration_seconds, str(trust_data)) - - logger.info( - f"[Device Trust] Device trusted for user {user_id}: " - f"{client_info.client_type} on {client_info.platform} " - f"(fingerprint: {device_fingerprint[:8]}...)" - ) - - @staticmethod - async def trust_location( - redis: Redis, - user_id: int, - country_code: str, - trust_duration_days: int | None = None, - ) -> None: - """ - 信任位置 - - Args: - redis: Redis 连接 - user_id: 用户 ID - country_code: 国家代码 - trust_duration_days: 信任持续天数 - """ - if not country_code: - return - - # 使用配置中的默认值 - if trust_duration_days is None: - trust_duration_days = settings.location_trust_duration_days - - trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code) - trust_data = { - "country_code": country_code, - "trusted_at": utcnow().isoformat(), - } - - # 设置信任期限 - trust_duration_seconds = trust_duration_days * 24 * 3600 - await redis.setex(trust_key, trust_duration_seconds, str(trust_data)) - - logger.info(f"[Location Trust] Location trusted for user {user_id}: {country_code}") - - @staticmethod - async def set_verification_cooldown( - redis: Redis, - user_id: int, - cooldown_seconds: int, - ) -> None: - """ - 设置验证冷却期 - - Args: - redis: Redis 连接 - user_id: 用户 ID - cooldown_seconds: 冷却时间(秒) - """ - cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id) - cooldown_data = { - "set_at": utcnow().isoformat(), - "expires_at": (utcnow() + timedelta(seconds=cooldown_seconds)).isoformat(), - } - - await redis.setex(cooldown_key, cooldown_seconds, str(cooldown_data)) - - logger.info(f"[Verification Cooldown] Set cooldown for user {user_id}: {cooldown_seconds}s") - - @staticmethod - async def should_require_verification( - redis: Redis, - user_id: int, - device_fingerprint: str | None, - country_code: str | None, - client_info: ClientInfo, - is_new_location: bool, - ) -> tuple[bool, str]: - """ - 判断是否需要验证 - - Args: - redis: Redis 连接 - user_id: 用户 ID - device_fingerprint: 设备指纹 - country_code: 国家代码 - client_info: 客户端信息 - is_new_location: 是否为新位置 - - Returns: - tuple[bool, str]: (是否需要验证, 原因) - """ - # 检查验证冷却期 - if await DeviceTrustService.is_in_verification_cooldown(redis, user_id): - return False, "用户在验证冷却期内" - - # 检查设备信任 - if device_fingerprint and await DeviceTrustService.is_device_trusted(redis, user_id, device_fingerprint): - return False, "设备已受信任" - - # 检查位置信任 - if country_code and await DeviceTrustService.is_location_trusted(redis, user_id, country_code): - return False, "位置已受信任" - - # 受信任的客户端类型降低验证要求 - if client_info.is_trusted_client and not is_new_location: - return False, "受信任客户端且非新位置" - - # 如果是新位置登录,需要验证 - if is_new_location: - return True, "新位置登录需要验证" - - # 默认不需要验证 - return False, "常规登录无需验证" - - @staticmethod - async def mark_verification_successful( - redis: Redis, - user_id: int, - device_fingerprint: str | None, - country_code: str | None, - client_info: ClientInfo, - ) -> None: - """ - 标记验证成功,更新信任信息 - - Args: - redis: Redis 连接 - user_id: 用户 ID - device_fingerprint: 设备指纹 - country_code: 国家代码 - client_info: 客户端信息 - """ - # 信任设备 - if device_fingerprint: - await DeviceTrustService.trust_device(redis, user_id, device_fingerprint, client_info) - - # 信任位置 - if country_code: - await DeviceTrustService.trust_location(redis, user_id, country_code) - - # 设置验证冷却期 - cooldown_seconds = (client_info.is_trusted_client and 3600) or 1800 # 受信任客户端1小时,其他30分钟 - await DeviceTrustService.set_verification_cooldown(redis, user_id, cooldown_seconds) - - logger.info(f"[Device Trust] Verification successful for user {user_id}, trust updated") diff --git a/app/service/login_log_service.py b/app/service/login_log_service.py index d771019..6fa2f1a 100644 --- a/app/service/login_log_service.py +++ b/app/service/login_log_service.py @@ -9,7 +9,7 @@ import asyncio from app.database.user_login_log import UserLoginLog from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip from app.log import logger -from app.utils import simplify_user_agent, utcnow +from app.utils import utcnow from fastapi import Request from sqlmodel.ext.asyncio.session import AsyncSession @@ -23,6 +23,7 @@ class LoginLogService: db: AsyncSession, user_id: int, request: Request, + user_agent: str | None = None, login_success: bool = True, login_method: str = "password", notes: str | None = None, @@ -45,9 +46,6 @@ class LoginLogService: raw_ip = get_client_ip(request) ip_address = normalize_ip(raw_ip) - raw_user_agent = request.headers.get("User-Agent", "") - user_agent = simplify_user_agent(raw_user_agent, max_length=500) - # 创建基本的登录记录 login_log = UserLoginLog( user_id=user_id, @@ -107,6 +105,7 @@ class LoginLogService: attempted_username: str | None = None, login_method: str = "password", notes: str | None = None, + user_agent: str | None = None, ) -> UserLoginLog: """ 记录失败的登录尝试 @@ -128,6 +127,7 @@ class LoginLogService: request=request, login_success=False, login_method=login_method, + user_agent=user_agent, notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt", ) diff --git a/app/service/session_manager.py b/app/service/session_manager.py deleted file mode 100644 index 29e73f3..0000000 --- a/app/service/session_manager.py +++ /dev/null @@ -1,122 +0,0 @@ -""" -API 状态管理 - 模拟 osu! 的 APIState 和会话管理 -""" - -from __future__ import annotations - -from datetime import datetime -from enum import Enum - -from pydantic import BaseModel - - -class APIState(str, Enum): - """API 连接状态,对应 osu! 的 APIState""" - - OFFLINE = "offline" - CONNECTING = "connecting" - REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证 - ONLINE = "online" - FAILING = "failing" - - -class UserSession(BaseModel): - """用户会话信息""" - - user_id: int - username: str - email: str - session_token: str | None = None - state: APIState = APIState.OFFLINE - requires_verification: bool = False - verification_sent: bool = False - last_verification_attempt: datetime | None = None - failed_attempts: int = 0 - ip_address: str | None = None - country_code: str | None = None - is_new_location: bool = False - - -class SessionManager: - """会话管理器""" - - def __init__(self): - self._sessions: dict[str, UserSession] = {} - - def create_session( - self, - user_id: int, - username: str, - email: str, - ip_address: str, - country_code: str | None = None, - is_new_location: bool = False, - ) -> UserSession: - """创建新的用户会话""" - import secrets - - session_token = secrets.token_urlsafe(32) - - # 根据是否为新位置决定初始状态 - if is_new_location: - state = APIState.REQUIRES_SECOND_FACTOR_AUTH - else: - state = APIState.ONLINE - - session = UserSession( - user_id=user_id, - username=username, - email=email, - session_token=session_token, - state=state, - requires_verification=is_new_location, - ip_address=ip_address, - country_code=country_code, - is_new_location=is_new_location, - ) - - self._sessions[session_token] = session - return session - - def get_session(self, session_token: str) -> UserSession | None: - """获取会话""" - return self._sessions.get(session_token) - - def update_session_state(self, session_token: str, state: APIState): - """更新会话状态""" - if session_token in self._sessions: - self._sessions[session_token].state = state - - def mark_verification_sent(self, session_token: str): - """标记验证邮件已发送""" - if session_token in self._sessions: - session = self._sessions[session_token] - session.verification_sent = True - session.last_verification_attempt = datetime.now() - - def increment_failed_attempts(self, session_token: str): - """增加失败尝试次数""" - if session_token in self._sessions: - self._sessions[session_token].failed_attempts += 1 - - def verify_session(self, session_token: str) -> bool: - """验证会话成功""" - if session_token in self._sessions: - session = self._sessions[session_token] - session.state = APIState.ONLINE - session.requires_verification = False - return True - return False - - def remove_session(self, session_token: str): - """移除会话""" - self._sessions.pop(session_token, None) - - def cleanup_expired_sessions(self): - """清理过期会话""" - # 这里可以实现清理逻辑 - pass - - -# 全局会话管理器 -session_manager = SessionManager() diff --git a/app/service/verification_service.py b/app/service/verification_service.py index e4ce512..053486a 100644 --- a/app/service/verification_service.py +++ b/app/service/verification_service.py @@ -10,15 +10,15 @@ import string from typing import Literal from app.config import settings -from app.database.verification import EmailVerification, LoginSession +from app.database.auth import OAuthToken +from app.database.verification import EmailVerification, LoginSession, TrustedDevice from app.log import logger -from app.service.client_detection_service import ClientDetectionService, ClientInfo -from app.service.device_trust_service import DeviceTrustService -from app.service.email_queue import email_queue # 导入邮件队列 +from app.models.model import UserAgentInfo +from app.service.email_queue import email_queue from app.utils import utcnow from redis.asyncio import Redis -from sqlmodel import exists, select +from sqlmodel import col, exists, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -248,11 +248,9 @@ This email was sent automatically, please do not reply. username: str, email: str, ip_address: str | None = None, - user_agent: str | None = None, - client_id: int | None = None, - country_code: str | None = None, + user_agent: UserAgentInfo | None = None, ) -> bool: - """发送验证邮件(带智能检测)""" + """发送验证邮件""" try: # 检查是否启用邮件验证功能 if not settings.enable_email_verification: @@ -260,32 +258,14 @@ This email was sent automatically, please do not reply. return True # 返回成功,但不执行验证流程 # 检测客户端信息 - client_info = ClientDetectionService.detect_client(user_agent, client_id) - logger.info( - f"[Email Verification] Detected client for user {user_id}: " - f"{ClientDetectionService.format_client_display_name(client_info)}" - ) - - # 检查是否需要验证 - needs_verification, reason = await DeviceTrustService.should_require_verification( - redis=redis, - user_id=user_id, - device_fingerprint=client_info.device_fingerprint, - country_code=country_code, - client_info=client_info, - is_new_location=True, # 这里需要从调用方传入 - ) - - if not needs_verification: - logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}") - return True + logger.info(f"[Email Verification] Detected client for user {user_id}: {user_agent}") # 创建验证记录 ( _, code, ) = await EmailVerificationService.create_verification_record( - db, redis, user_id, email, ip_address, user_agent + db, redis, user_id, email, ip_address, user_agent.raw_ua if user_agent else None ) # 使用邮件队列发送验证邮件 @@ -304,107 +284,6 @@ This email was sent automatically, please do not reply. logger.error(f"[Email Verification] Exception during sending verification email: {e}") return False - @staticmethod - async def send_smart_verification_email( - db: AsyncSession, - redis: Redis, - user_id: int, - username: str, - email: str, - ip_address: str | None = None, - user_agent: str | None = None, - client_id: int | None = None, - country_code: str | None = None, - is_new_location: bool = False, - ) -> tuple[bool, str, ClientInfo | None]: - """ - 智能邮件验证发送 - - Args: - db: 数据库会话 - redis: Redis 连接 - user_id: 用户 ID - username: 用户名 - email: 邮箱地址 - ip_address: IP 地址 - user_agent: 用户代理 - client_id: 客户端 ID - country_code: 国家代码 - is_new_location: 是否为新位置登录 - - Returns: - tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息) - """ - try: - # 检查是否启用邮件验证功能 - if not settings.enable_email_verification: - logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}") - return True, "邮件验证功能已禁用", None - - # 检查是否启用智能验证 - if not settings.enable_smart_verification: - logger.debug( - f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}" - ) - # 回退到传统验证逻辑 - verification, code = await EmailVerificationService.create_verification_record( - db, redis, user_id, email, ip_address, user_agent - ) - success = await EmailVerificationService.send_verification_email_via_queue( - email, code, username, user_id - ) - return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None - - # 检测客户端信息 - client_info = ClientDetectionService.detect_client(user_agent, client_id) - client_display_name = ClientDetectionService.format_client_display_name(client_info) - - logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}") - - # 检查是否需要验证 - needs_verification, reason = await DeviceTrustService.should_require_verification( - redis=redis, - user_id=user_id, - device_fingerprint=client_info.device_fingerprint, - country_code=country_code, - client_info=client_info, - is_new_location=is_new_location, - ) - - if not needs_verification: - logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}") - - # 即使不需要验证,也要更新设备信任信息 - if client_info.device_fingerprint: - await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info) - if country_code: - await DeviceTrustService.trust_location(redis, user_id, country_code) - - return True, f"跳过验证: {reason}", client_info - - # 创建验证记录 - verification, code = await EmailVerificationService.create_verification_record( - db, redis, user_id, email, ip_address, user_agent - ) - _ = verification # 避免未使用变量警告 - - # 使用邮件队列发送验证邮件 - success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id) - - if success: - logger.info( - f"[Smart Verification] Successfully sent verification email to {email} " - f"for user {username} using {client_display_name}" - ) - return True, "验证邮件已发送", client_info - else: - logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})") - return False, "验证邮件发送失败", client_info - - except Exception as e: - logger.error(f"[Smart Verification] Exception during smart verification: {e}") - return False, f"验证过程中发生错误: {e!s}", None - @staticmethod async def verify_email_code( db: AsyncSession, @@ -416,7 +295,7 @@ This email was sent automatically, please do not reply. client_id: int | None = None, country_code: str | None = None, ) -> tuple[bool, str]: - """验证邮箱验证码(带智能信任更新)""" + """验证邮箱验证码""" try: # 检查是否启用邮件验证功能 if not settings.enable_email_verification: @@ -452,16 +331,6 @@ This email was sent automatically, please do not reply. # 删除 Redis 记录 await redis.delete(f"email_verification:{user_id}:{code}") - # 检测客户端信息并更新信任状态 - client_info = ClientDetectionService.detect_client(user_agent, client_id) - await DeviceTrustService.mark_verification_successful( - redis=redis, - user_id=user_id, - device_fingerprint=client_info.device_fingerprint, - country_code=country_code, - client_info=client_info, - ) - logger.info(f"[Email Verification] User {user_id} verification code verified successfully") return True, "验证成功" @@ -477,7 +346,7 @@ This email was sent automatically, please do not reply. username: str, email: str, ip_address: str | None = None, - user_agent: str | None = None, + user_agent: UserAgentInfo | None = None, ) -> tuple[bool, str]: """重新发送验证码""" try: @@ -516,12 +385,12 @@ class LoginSessionService: # Session verification interface methods @staticmethod - async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None: + async def find_for_verification(db: AsyncSession, token: str) -> LoginSession | None: """根据会话ID查找会话用于验证""" try: result = await db.exec( select(LoginSession).where( - LoginSession.session_token == session_id, + col(LoginSession.token).has(col(OAuthToken.access_token) == token), LoginSession.expires_at > utcnow(), ) ) @@ -537,42 +406,31 @@ class LoginSessionService: @staticmethod async def create_session( db: AsyncSession, - redis: Redis, user_id: int, token_id: int, ip_address: str, user_agent: str | None = None, - country_code: str | None = None, - is_new_location: bool = False, + is_new_device: bool = False, + web_uuid: str | None = None, is_verified: bool = False, ) -> LoginSession: """创建登录会话""" - - session_token = EmailVerificationService.generate_session_token() - session = LoginSession( user_id=user_id, token_id=token_id, ip_address=ip_address, - user_agent=None, - country_code=country_code, - is_new_location=is_new_location, + user_agent=user_agent, + is_new_device=is_new_device, expires_at=utcnow() + timedelta(hours=24), # 24小时过期 is_verified=is_verified, + web_uuid=web_uuid, ) db.add(session) await db.commit() await db.refresh(session) - # 存储到 Redis - await redis.setex( - f"login_session:{session_token}", - 86400, # 24小时 - user_id, - ) - - logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})") + logger.info(f"[Login Session] Created session for user {user_id} (new device: {is_new_device})") return session @classmethod @@ -592,35 +450,98 @@ class LoginSessionService: await redis.delete(cls._session_verify_redis_key(user_id, token_id)) @staticmethod - async def check_new_location( - db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None + async def check_trusted_device( + db: AsyncSession, user_id: int, ip_address: str, user_agent: UserAgentInfo, web_uuid: str | None = None ) -> bool: - """检查是否为新位置登录""" - try: - # 查看过去30天内是否有相同IP或相同国家的登录记录 - thirty_days_ago = utcnow() - timedelta(days=30) - - result = await db.exec( - select(LoginSession).where( - LoginSession.user_id == user_id, - LoginSession.created_at > thirty_days_ago, - (LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code), - ) + if user_agent.is_client: + query = select(exists()).where( + TrustedDevice.user_id == user_id, + TrustedDevice.client_type == "client", + TrustedDevice.ip_address == ip_address, + TrustedDevice.expires_at > utcnow(), ) - - existing_sessions = result.all() - - # 如果有历史记录,则不是新位置 - return len(existing_sessions) == 0 - - except Exception as e: - logger.error(f"[Login Session] Exception during new location check: {e}") - # 出错时默认为新位置(更安全) - return True + else: + if web_uuid is None: + return False + query = select(exists()).where( + TrustedDevice.user_id == user_id, + TrustedDevice.client_type == "web", + TrustedDevice.web_uuid == web_uuid, + TrustedDevice.expires_at > utcnow(), + ) + return (await db.exec(query)).first() or False @staticmethod - async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool: + async def create_trusted_device( + db: AsyncSession, + user_id: int, + ip_address: str, + user_agent: UserAgentInfo, + web_uuid: str | None = None, + ) -> TrustedDevice: + device = TrustedDevice( + user_id=user_id, + ip_address=ip_address, + user_agent=user_agent.raw_ua, + client_type="client" if user_agent.is_client else "web", + web_uuid=web_uuid if not user_agent.is_client else None, + expires_at=utcnow() + timedelta(days=settings.device_trust_duration_days), + ) + db.add(device) + await db.commit() + await db.refresh(device) + return device + + @staticmethod + async def get_or_create_trusted_device( + db: AsyncSession, + user_id: int, + ip_address: str, + user_agent: UserAgentInfo, + web_uuid: str | None = None, + ) -> TrustedDevice: + if user_agent.is_client: + query = select(TrustedDevice).where( + TrustedDevice.user_id == user_id, + TrustedDevice.client_type == "client", + TrustedDevice.ip_address == ip_address, + ) + else: + if web_uuid is None: + raise ValueError("web_uuid is required for web clients") + query = select(TrustedDevice).where( + TrustedDevice.user_id == user_id, + TrustedDevice.client_type == "web", + TrustedDevice.web_uuid == web_uuid, + ) + + device = (await db.exec(query)).first() + if device is None: + device = await LoginSessionService.create_trusted_device(db, user_id, ip_address, user_agent, web_uuid) + else: + device.last_used_at = utcnow() + device.expires_at = utcnow() + timedelta(days=settings.device_trust_duration_days) + await db.commit() + await db.refresh(device) + return device + + @staticmethod + async def mark_session_verified( + db: AsyncSession, + redis: Redis, + user_id: int, + token_id: int, + ip_address: str, + user_agent: UserAgentInfo, + web_uuid: str | None = None, + ) -> bool: """标记用户的未验证会话为已验证""" + device_info: TrustedDevice | None = None + if user_agent.is_client or web_uuid: + device_info = await LoginSessionService.get_or_create_trusted_device( + db, user_id, ip_address, user_agent, web_uuid + ) + try: # 查找用户所有未验证且未过期的会话 result = await db.exec( @@ -631,18 +552,20 @@ class LoginSessionService: LoginSession.token_id == token_id, ) ) - sessions = result.all() # 标记所有会话为已验证 for session in sessions: session.is_verified = True session.verified_at = utcnow() + if device_info: + session.device_id = device_info.id if sessions: logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}") await LoginSessionService.clear_login_method(user_id, token_id, redis) + await db.commit() return len(sessions) > 0 @@ -658,7 +581,7 @@ class LoginSessionService: await db.exec( select(exists()).where( LoginSession.user_id == user_id, - LoginSession.is_verified == False, # noqa: E712 + col(LoginSession.is_verified).is_(False), LoginSession.expires_at > utcnow(), LoginSession.token_id == token_id, ) diff --git a/app/utils.py b/app/utils.py index 563bd4a..9a610e0 100644 --- a/app/utils.py +++ b/app/utils.py @@ -6,11 +6,15 @@ from datetime import UTC, datetime import functools import inspect from io import BytesIO -from typing import Any, ParamSpec, TypeVar +import re +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from fastapi import HTTPException from PIL import Image +if TYPE_CHECKING: + from app.models.model import UserAgentInfo + def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" @@ -154,81 +158,79 @@ def check_image(content: bytes, size: int, width: int, height: int) -> str: raise HTTPException(status_code=400, detail=f"Error processing image: {e}") -def simplify_user_agent(user_agent: str | None, max_length: int = 200) -> str | None: - """ - 简化 User-Agent 字符串,只保留 osu! 和关键设备系统信息浏览器 +def extract_user_agent(user_agent: str | None) -> "UserAgentInfo": + from app.models.model import UserAgentInfo - Args: - user_agent: 原始 User-Agent 字符串 - max_length: 最大长度限制 + raw_ua = user_agent or "" + ua = raw_ua.strip() + lower_ua = ua.lower() - Returns: - 简化后的 User-Agent 字符串,或 None - """ - import re + info = UserAgentInfo(raw_ua=raw_ua) - if not user_agent: - return None + if not ua: + return info - # 如果长度在限制内,直接返回 - if len(user_agent) <= max_length: - return user_agent + client_identifiers = ("osu!", "osu!lazer", "osu-framework") + if any(identifier in lower_ua for identifier in client_identifiers): + info.browser = "osu!" + info.is_client = True + return info - # 提取操作系统信息 - os_info = "" - os_patterns = [ - r"(Windows[^;)]*)", - r"(Mac OS[^;)]*)", - r"(Linux[^;)]*)", - r"(Android[^;)]*)", - r"(iOS[^;)]*)", - r"(iPhone[^;)]*)", - r"(iPad[^;)]*)", - ] + browser_patterns: tuple[tuple[re.Pattern[str], str], ...] = ( + (re.compile(r"OPR/(\d+(?:\.\d+)*)"), "Opera"), + (re.compile(r"Edg/(\d+(?:\.\d+)*)"), "Edge"), + (re.compile(r"Chrome/(\d+(?:\.\d+)*)"), "Chrome"), + (re.compile(r"Firefox/(\d+(?:\.\d+)*)"), "Firefox"), + (re.compile(r"Version/(\d+(?:\.\d+)*).*Safari"), "Safari"), + (re.compile(r"Safari/(\d+(?:\.\d+)*)"), "Safari"), + (re.compile(r"MSIE (\d+(?:\.\d+)*)"), "Internet Explorer"), + (re.compile(r"Trident/.*rv:(\d+(?:\.\d+)*)"), "Internet Explorer"), + ) - for pattern in os_patterns: - match = re.search(pattern, user_agent, re.IGNORECASE) + for pattern, name in browser_patterns: + match = pattern.search(ua) if match: - os_info = match.group(1).strip() + info.browser = name + info.version = match.group(1) break - # 提取浏览器信息 - browser_info = "" - browser_patterns = [ - r"(osu![^)]*)", # osu! 客户端 - r"(Chrome/[\d.]+)", - r"(Firefox/[\d.]+)", - r"(Safari/[\d.]+)", - r"(Edge/[\d.]+)", - r"(Opera/[\d.]+)", - ] + os_patterns: tuple[tuple[re.Pattern[str], str], ...] = ( + (re.compile(r"windows nt 10"), "Windows 10"), + (re.compile(r"windows nt 6\.3"), "Windows 8.1"), + (re.compile(r"windows nt 6\.2"), "Windows 8"), + (re.compile(r"windows nt 6\.1"), "Windows 7"), + (re.compile(r"windows nt 6\.0"), "Windows Vista"), + (re.compile(r"windows nt 5\.1"), "Windows XP"), + (re.compile(r"mac os x"), "macOS"), + (re.compile(r"iphone os"), "iOS"), + (re.compile(r"ipad;"), "iPadOS"), + (re.compile(r"android"), "Android"), + (re.compile(r"linux"), "Linux"), + ) - for pattern in browser_patterns: - match = re.search(pattern, user_agent, re.IGNORECASE) - if match: - browser_info = match.group(1).strip() - # 如果找到了 osu! 客户端,优先使用 - if "osu!" in browser_info.lower(): - break + for pattern, name in os_patterns: + if pattern.search(lower_ua): + info.os = name + break - # 构建简化的 User-Agent - parts = [] - if os_info: - parts.append(os_info) - if browser_info: - parts.append(browser_info) + info.is_mobile = any(keyword in lower_ua for keyword in ("mobile", "iphone", "android", "ipod")) + info.is_tablet = any(keyword in lower_ua for keyword in ("ipad", "tablet")) + # Only classify as PC if not mobile or tablet + if ( + not info.is_mobile + and not info.is_tablet + and any(keyword in lower_ua for keyword in ("windows", "macintosh", "linux", "x11")) + ): + info.is_pc = True - if parts: - simplified = "; ".join(parts) - else: - # 如果没有识别到关键信息,截断原始字符串 - simplified = user_agent[: max_length - 3] + "..." + if info.is_tablet: + info.platform = "tablet" + elif info.is_mobile: + info.platform = "mobile" + elif info.is_pc: + info.platform = "pc" - # 确保不超过最大长度 - if len(simplified) > max_length: - simplified = simplified[: max_length - 3] + "..." - - return simplified + return info # https://github.com/encode/starlette/blob/master/starlette/_utils.py diff --git a/main.py b/main.py index 19cbde7..d5c7f29 100644 --- a/main.py +++ b/main.py @@ -25,10 +25,6 @@ from app.router import ( from app.router.redirect import redirect_router from app.router.v1 import api_v1_public_router from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler -from app.scheduler.database_cleanup_scheduler import ( - start_database_cleanup_scheduler, - stop_database_cleanup_scheduler, -) from app.service.beatmap_download_service import download_service from app.service.beatmapset_update_service import init_beatmapset_update_service from app.service.calculate_all_user_rank import calculate_user_rank @@ -68,7 +64,6 @@ async def lifespan(app: FastAPI): await start_email_processor() # 启动邮件队列处理器 await download_service.start_health_check() # 启动下载服务健康检查 await start_cache_scheduler() # 启动缓存调度器 - await start_database_cleanup_scheduler() # 启动数据库清理调度器 init_beatmapset_update_service(fetcher) # 初始化谱面集更新服务 redis_message_system.start() # 启动 Redis 消息系统 load_achievements() @@ -83,7 +78,6 @@ async def lifespan(app: FastAPI): stop_scheduler() redis_message_system.stop() # 停止 Redis 消息系统 await stop_cache_scheduler() # 停止缓存调度器 - await stop_database_cleanup_scheduler() # 停止数据库清理调度器 await download_service.stop_health_check() # 停止下载服务健康检查 await stop_email_processor() # 停止邮件队列处理器 await engine.dispose() diff --git a/migrations/versions/2025-10-02_72a9b8f3f863_session_support_multi_session.py b/migrations/versions/2025-10-02_72a9b8f3f863_session_support_multi_session.py new file mode 100644 index 0000000..6e47be4 --- /dev/null +++ b/migrations/versions/2025-10-02_72a9b8f3f863_session_support_multi_session.py @@ -0,0 +1,102 @@ +"""session: support multi-session + +Revision ID: 72a9b8f3f863 +Revises: b1ac2154bd0d +Create Date: 2025-10-02 07:17:19.297498 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "72a9b8f3f863" +down_revision: str | Sequence[str] | None = "b1ac2154bd0d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "trusted_devices", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("ip_address", sa.VARCHAR(length=45), nullable=False), + sa.Column("user_agent", sa.Text(), nullable=False), + sa.Column("client_type", sa.VARCHAR(length=10), nullable=False), + sa.Column("web_uuid", sa.VARCHAR(length=36), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("last_used_at", sa.DateTime(), nullable=False), + sa.Column("expires_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.alter_column( + "login_sessions", + "is_new_location", + new_column_name="is_new_device", + existing_type=mysql.TINYINT(display_width=1), + ) + op.create_index(op.f("ix_trusted_devices_user_id"), "trusted_devices", ["user_id"], unique=False) + op.add_column("login_sessions", sa.Column("web_uuid", sa.VARCHAR(length=36), nullable=True)) + op.alter_column( + "login_sessions", + "ip_address", + existing_type=mysql.VARCHAR(length=255), + type_=sa.VARCHAR(length=45), + existing_nullable=False, + ) + op.alter_column( + "login_sessions", "user_agent", existing_type=mysql.VARCHAR(length=250), type_=sa.Text(), existing_nullable=True + ) + op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions") + op.create_foreign_key(None, "login_sessions", "lazer_users", ["user_id"], ["id"]) + op.drop_column("login_sessions", "country_code") + op.drop_column("login_sessions", "session_token") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("login_sessions", sa.Column("session_token", sa.VARCHAR(length=64), nullable=True)) + op.add_column("login_sessions", sa.Column("country_code", sa.VARCHAR(length=255), nullable=True)) + op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False) + + op.alter_column( + "login_sessions", + "user_agent", + existing_type=sa.Text(), + type_=mysql.VARCHAR(length=250), + existing_nullable=True, + ) + op.alter_column( + "login_sessions", + "ip_address", + existing_type=sa.String(length=45), + type_=mysql.VARCHAR(length=255), + existing_nullable=False, + ) + + op.drop_column("login_sessions", "web_uuid") + op.alter_column( + "login_sessions", + "is_new_device", + new_column_name="is_new_location", + existing_type=mysql.TINYINT(display_width=1), + ) + op.drop_constraint(op.f("fk_login_sessions_user_id_lazer_users"), "login_sessions", type_="foreignkey") + + op.drop_index(op.f("ix_trusted_devices_user_id"), table_name="trusted_devices") + op.drop_table("trusted_devices") + # ### end Alembic commands ### diff --git a/migrations/versions/2025-10-02_7fe1319250c5_auth_add_refresh_token_expires_at.py b/migrations/versions/2025-10-02_7fe1319250c5_auth_add_refresh_token_expires_at.py new file mode 100644 index 0000000..07f8bc9 --- /dev/null +++ b/migrations/versions/2025-10-02_7fe1319250c5_auth_add_refresh_token_expires_at.py @@ -0,0 +1,40 @@ +"""auth: add refresh_token_expires_at + +Revision ID: 7fe1319250c5 +Revises: 72a9b8f3f863 +Create Date: 2025-10-02 10:50:21.169065 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "7fe1319250c5" +down_revision: str | Sequence[str] | None = "72a9b8f3f863" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("oauth_tokens", sa.Column("refresh_token_expires_at", sa.DateTime(), nullable=True)) + op.create_index(op.f("ix_oauth_tokens_expires_at"), "oauth_tokens", ["expires_at"], unique=False) + op.create_index( + op.f("ix_oauth_tokens_refresh_token_expires_at"), "oauth_tokens", ["refresh_token_expires_at"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_oauth_tokens_refresh_token_expires_at"), table_name="oauth_tokens") + op.drop_index(op.f("ix_oauth_tokens_expires_at"), table_name="oauth_tokens") + op.drop_column("oauth_tokens", "refresh_token_expires_at") + # ### end Alembic commands ### diff --git a/migrations/versions/2025-10-02_9556cd2ec11f_session_add_device_id_to_loginsession.py b/migrations/versions/2025-10-02_9556cd2ec11f_session_add_device_id_to_loginsession.py new file mode 100644 index 0000000..bb258a7 --- /dev/null +++ b/migrations/versions/2025-10-02_9556cd2ec11f_session_add_device_id_to_loginsession.py @@ -0,0 +1,35 @@ +"""session: add device_id to LoginSession + +Revision ID: 9556cd2ec11f +Revises: 7fe1319250c5 +Create Date: 2025-10-02 11:03:09.803140 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "9556cd2ec11f" +down_revision: str | Sequence[str] | None = "7fe1319250c5" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("login_sessions", sa.Column("device_id", sa.BigInteger(), nullable=True)) + op.create_index(op.f("ix_login_sessions_device_id"), "login_sessions", ["device_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("login_sessions", "device_id") + # ### end Alembic commands ###