feat(auth): support trusted device (#52)

New API to maintain sessions and devices:

- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}

Auth:

web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-03 11:26:43 +08:00
committed by GitHub
parent f34ed53a55
commit 40670c094b
28 changed files with 897 additions and 1456 deletions

View File

@@ -217,10 +217,12 @@ async def store_token(
access_token: str,
refresh_token: str,
expires_in: int,
refresh_token_expires_in: int,
allow_multiple_devices: bool = True,
) -> OAuthToken:
"""存储令牌到数据库(支持多设备)"""
expires_at = utcnow() + timedelta(seconds=expires_in)
refresh_token_expires_at = utcnow() + timedelta(seconds=refresh_token_expires_in)
if not allow_multiple_devices:
# 旧的行为:删除用户的旧令牌(单设备模式)
@@ -266,6 +268,7 @@ async def store_token(
scope=",".join(scopes),
refresh_token=refresh_token,
expires_at=expires_at,
refresh_token_expires_at=refresh_token_expires_at,
)
db.add(token_record)
await db.commit()
@@ -290,7 +293,7 @@ async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OA
"""根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > utcnow(),
OAuthToken.refresh_token_expires_at > utcnow(),
)
return (await db.exec(statement)).first()

View File

@@ -170,6 +170,11 @@ STORAGE_SETTINGS='{
Field(default=1440, description="访问令牌过期时间(分钟)"),
"JWT 设置",
]
refresh_token_expire_minutes: Annotated[
int,
Field(default=21600, description="刷新令牌过期时间(分钟)"),
"JWT 设置",
] # 15 days
jwt_audience: Annotated[
str,
Field(default="5", description="JWT 受众"),
@@ -349,11 +354,6 @@ STORAGE_SETTINGS='{
Field(default=30, description="设备信任持续天数"),
"验证服务设置",
]
location_trust_duration_days: Annotated[
int,
Field(default=90, description="位置信任持续天数"),
"验证服务设置",
]
smtp_server: Annotated[
str,
Field(default="localhost", description="SMTP 服务器地址"),

View File

@@ -3,3 +3,5 @@ from __future__ import annotations
BANCHOBOT_ID = 2
BACKUP_CODE_LENGTH = 10
SUPPORT_TOTP_VERIFICATION_VER = 20250913

View File

@@ -68,7 +68,7 @@ from .user_account_history import (
UserAccountHistoryType,
)
from .user_login_log import UserLoginLog
from .verification import EmailVerification, LoginSession
from .verification import EmailVerification, LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
__all__ = [
"APIUploadedRoom",
@@ -96,6 +96,7 @@ __all__ = [
"ItemAttemptsCount",
"ItemAttemptsResp",
"LoginSession",
"LoginSessionResp",
"MeResp",
"MonthlyPlaycounts",
"MultiplayerEvent",
@@ -131,6 +132,8 @@ __all__ = [
"TeamMember",
"TeamRequest",
"TotpKeys",
"TrustedDevice",
"TrustedDeviceResp",
"User",
"UserAccountHistory",
"UserAccountHistoryResp",

View File

@@ -32,7 +32,8 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20)
scope: str = Field(default="*", max_length=100)
expires_at: datetime = Field(sa_column=Column(DateTime))
expires_at: datetime = Field(sa_column=Column(DateTime, index=True))
refresh_token_expires_at: datetime = Field(sa_column=Column(DateTime, index=True))
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
user: "User" = Relationship()

View File

@@ -243,7 +243,6 @@ class UserResp(UserBase):
user_achievements: list[UserAchievementResp] = Field(default_factory=list)
cover_url: str = "" # deprecated
team: Team | None = None
session_verified: bool = True
daily_challenge_user_stats: DailyChallengeStatsResp | None = None
default_group: str = ""
is_deleted: bool = False # TODO
@@ -425,27 +424,18 @@ class UserResp(UserBase):
)
).one()
if "session_verified" in include:
from app.service.verification_service import LoginSessionService
u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id
else True
)
return u
class MeResp(UserResp):
session_verification_method: Literal["totp", "mail"] | None = None
session_verified: bool = True
@classmethod
async def from_db(
cls,
obj: User,
session: AsyncSession,
include: list[str] = [],
ruleset: GameMode | None = None,
*,
token_id: int | None = None,
@@ -453,7 +443,12 @@ class MeResp(UserResp):
from app.dependencies.database import get_redis
from app.service.verification_service import LoginSessionService
u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id)
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
u.session_verified = (
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
if token_id
else True
)
u = cls.model_validate(u.model_dump())
if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
redis = get_redis()

View File

@@ -3,17 +3,26 @@
"""
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional
from app.utils import utcnow
from app.helpers.geoip_helper import GeoIPHelper
from app.models.model import UserAgentInfo, UTCBaseModel
from app.utils import extract_user_agent, utcnow
from pydantic import BaseModel
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, Integer, Relationship, SQLModel
from sqlmodel import VARCHAR, DateTime, Field, Integer, Relationship, SQLModel, Text
if TYPE_CHECKING:
from .auth import OAuthToken
class Location(BaseModel):
country: str = ""
city: str = ""
country_code: str = ""
class EmailVerification(SQLModel, table=True):
"""邮件验证记录"""
@@ -31,25 +40,90 @@ class EmailVerification(SQLModel, table=True):
user_agent: str | None = Field(default=None) # 用户代理
class LoginSession(SQLModel, table=True):
class LoginSessionBase(SQLModel):
"""登录会话记录"""
__tablename__: str = "login_sessions"
id: int | None = Field(default=None, primary_key=True)
id: int = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
token_id: int | None = Field(
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True)
)
ip_address: str = Field() # 登录IP
user_agent: str | None = Field(default=None, max_length=250)
country_code: str | None = Field(default=None)
ip_address: str = Field(sa_column=Column(VARCHAR(45), nullable=False), default="127.0.0.1", exclude=True)
user_agent: str | None = Field(default=None, sa_column=Column(Text))
is_verified: bool = Field(default=False) # 是否已验证
created_at: datetime = Field(default_factory=lambda: utcnow())
verified_at: datetime | None = Field(default=None)
expires_at: datetime = Field() # 会话过期时间
is_new_location: bool = Field(default=False) # 是否新位置登录
session_token: str | None = Field(default=None, max_length=64, index=True) # 会话令牌
verification_method: str | None = Field(default=None, max_length=20) # 验证方法 (totp/mail)
device_id: int | None = Field(
sa_column=Column(BigInteger, ForeignKey("trusted_devices.id", ondelete="SET NULL"), nullable=True, index=True),
default=None,
)
class LoginSession(LoginSessionBase, table=True):
__tablename__: str = "login_sessions"
token_id: int | None = Field(
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True),
exclude=True,
)
is_new_device: bool = Field(default=False, exclude=True) # 是否新位置登录
web_uuid: str | None = Field(sa_column=Column(VARCHAR(36), nullable=True), default=None, exclude=True)
verification_method: str | None = Field(default=None, max_length=20, exclude=True) # 验证方法 (totp/mail)
device: Optional["TrustedDevice"] = Relationship(back_populates="sessions")
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")
class LoginSessionResp(UTCBaseModel, LoginSessionBase):
user_agent_info: UserAgentInfo | None = None
location: Location | None = None
@classmethod
def from_db(cls, obj: LoginSession, get_geoip_helper: GeoIPHelper) -> "LoginSessionResp":
session = cls.model_validate(obj.model_dump())
session.user_agent_info = extract_user_agent(session.user_agent)
if obj.ip_address:
loc = get_geoip_helper.lookup(obj.ip_address)
session.location = Location(
country=loc.get("country_name", ""),
city=loc.get("city_name", ""),
country_code=loc.get("country_code", ""),
)
else:
session.location = None
return session
class TrustedDeviceBase(SQLModel):
id: int = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
ip_address: str = Field(sa_column=Column(VARCHAR(45), nullable=False), default="127.0.0.1", exclude=True)
user_agent: str = Field(sa_column=Column(Text, nullable=False))
client_type: Literal["web", "client"] = Field(sa_column=Column(VARCHAR(10), nullable=False), default="web")
created_at: datetime = Field(default_factory=utcnow)
last_used_at: datetime = Field(default_factory=utcnow)
expires_at: datetime = Field(sa_column=Column(DateTime))
class TrustedDevice(TrustedDeviceBase, table=True):
__tablename__: str = "trusted_devices"
web_uuid: str | None = Field(sa_column=Column(VARCHAR(36), nullable=True), default=None)
sessions: list["LoginSession"] = Relationship(back_populates="device", passive_deletes=True)
class TrustedDeviceResp(UTCBaseModel, TrustedDeviceBase):
user_agent_info: UserAgentInfo | None = None
location: Location | None = None
@classmethod
def from_db(cls, device: TrustedDevice, get_geoip_helper: GeoIPHelper) -> "TrustedDeviceResp":
device_ = cls.model_validate(device.model_dump())
device_.user_agent_info = extract_user_agent(device_.user_agent)
if device_.ip_address:
loc = get_geoip_helper.lookup(device_.ip_address)
device_.location = Location(
country=loc.get("country_name", ""),
city=loc.get("city_name", ""),
country_code=loc.get("country_code", ""),
)
else:
device_.location = None
return device_

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from typing import Annotated
from app.models.model import UserAgentInfo as UserAgentInfoModel
from app.utils import extract_user_agent
from fastapi import Depends, Header
def get_user_agent_info(user_agent: str | None = Header(None, include_in_schema=False)) -> UserAgentInfoModel:
return extract_user_agent(user_agent)
UserAgentInfo = Annotated[UserAgentInfoModel, Depends(get_user_agent_info)]

View File

@@ -1,301 +0,0 @@
"""
会话验证中间件和状态管理
基于osu-web的会话验证系统实现
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import ClassVar, Literal, cast
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from app.utils import bg_tasks
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
class SessionVerificationState:
"""会话验证状态管理类
参考osu-web的State类实现
"""
def __init__(self, session: LoginSession, user: User, redis: Redis):
self.session = session
self.user = user
self.redis = redis
@classmethod
async def get_current(
cls,
request: Request,
db: AsyncSession,
redis: Redis,
user: User,
) -> SessionVerificationState | None:
"""获取当前会话验证状态"""
try:
# 从请求头或token中获取会话信息
session_token = cls._extract_session_token(request)
if not session_token:
return None
# 查找会话
session = await LoginSessionService.find_for_verification(db, session_token)
if not session or session.user_id != user.id:
return None
return cls(session, user, redis)
except Exception as e:
logger.error(f"[Session Verification] Error getting current state: {e}")
return None
@staticmethod
def _extract_session_token(request: Request) -> str | None:
"""从请求中提取会话token"""
# 尝试从Authorization header提取
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:] # 移除"Bearer "前缀
# 可以扩展其他提取方式
return None
def get_method(self) -> str:
"""获取验证方法
参考osu-web的逻辑智能选择验证方法
"""
current_method = self.session.verification_method
if current_method is None:
# 智能选择验证方法
# 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证
# 这里简化为检查用户是否有TOTP
totp_key = getattr(self.user, "totp_key", None)
current_method = "totp" if totp_key else "mail"
# 设置验证方法
bg_tasks.add_task(self._set_verification_method, current_method)
return current_method
async def _set_verification_method(self, method: str) -> None:
"""内部方法:设置验证方法"""
try:
token_id = self.session.token_id
if token_id is not None and method in ["totp", "mail"]:
# 类型检查确保method是正确的字面量类型
verification_method = method if method in ["totp", "mail"] else "totp"
await LoginSessionService.set_login_method(
self.user.id,
token_id,
cast(Literal["totp", "mail"], verification_method),
self.redis,
)
except Exception as e:
logger.error(f"[Session Verification] Error setting verification method: {e}")
def is_verified(self) -> bool:
"""检查会话是否已验证"""
return self.session.is_verified
async def mark_verified(self) -> None:
"""标记会话为已验证"""
try:
async with with_db() as db:
# 创建专用数据库会话
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id)
except Exception as e:
logger.error(f"[Session Verification] Error marking session verified: {e}")
def get_key(self) -> str:
"""获取会话密钥"""
return str(self.session.id) if self.session.id else ""
def get_key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
return LoginSessionService.get_key_for_event(self.get_key())
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id
async def issue_mail_if_needed(self) -> None:
"""如果需要,发送验证邮件"""
try:
if self.get_method() == "mail":
from app.service.verification_service import EmailVerificationService
# 创建专用数据库会话发送邮件
async with with_db() as db:
await EmailVerificationService.send_verification_email(
db, self.redis, self.user.id, self.user.username, self.user.email, None, None
)
except Exception as e:
logger.error(f"[Session Verification] Error issuing mail: {e}")
class SessionVerificationController:
"""会话验证控制器
参考osu-web的Controller类实现
"""
# 需要跳过验证的路由参考osu-web的SKIP_VERIFICATION_ROUTES
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/me",
"/api/v2/logout",
"/oauth/token",
}
@staticmethod
def should_skip_verification(request: Request) -> bool:
"""检查是否应该跳过验证"""
path = request.url.path
return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES
@staticmethod
async def initiate_verification(
state: SessionVerificationState,
request: Request,
) -> Response:
"""启动会话验证流程
参考osu-web的initiate方法
"""
try:
method = state.get_method()
# 如果是邮件验证,发送验证邮件
if method == "mail":
await state.issue_mail_if_needed()
# API请求返回JSON响应
if request.url.path.startswith("/api/"):
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method})
# 其他情况可以扩展支持HTML响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"authentication": "verify", "method": method, "message": "Session verification required"},
)
except Exception as e:
logger.error(f"[Session Verification] Error initiating verification: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed"
)
class SessionVerificationMiddleware:
"""会话验证中间件
参考osu-web的VerifyUser中间件实现
"""
def __init__(self, app: Callable[[Request], Awaitable[Response]]):
self.app = app
async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
"""中间件主要逻辑"""
try:
# 检查是否需要跳过验证
if SessionVerificationController.should_skip_verification(request):
return await call_next(request)
# 获取依赖项
user = await self._get_user(request)
if not user:
# 未认证用户跳过验证
return await call_next(request)
# 获取数据库和Redis连接
async with with_db() as db:
redis = await self._get_redis()
# 获取会话验证状态
state = await SessionVerificationState.get_current(request, db, redis, user)
if not state:
# 无法获取会话状态,继续请求
return await call_next(request)
# 检查是否已验证
if state.is_verified():
# 已验证,继续请求
return await call_next(request)
# 检查是否需要验证
if not self._requires_verification(request):
return await call_next(request)
# 启动验证流程
return await SessionVerificationController.initiate_verification(state, request)
except Exception as e:
logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
# 出错时允许请求继续,避免阻塞正常流程
return await call_next(request)
async def _get_user(self, request: Request) -> User | None:
"""获取当前用户"""
try:
# 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入
# 简化实现实际应该从token中解析用户
return None # 暂时返回None需要实际实现
except Exception:
return None
async def _get_redis(self) -> Redis:
"""获取Redis连接"""
return get_redis()
def _requires_verification(self, request: Request) -> bool:
"""检查是否需要验证
参考osu-web的requiresVerification方法
"""
method = request.method
# GET/HEAD/OPTIONS请求一般不需要验证
safe_methods = {"GET", "HEAD", "OPTIONS"}
if method in safe_methods:
return False
# POST/PUT/DELETE等修改操作需要验证
return True
# FastAPI中间件包装器
class FastAPISessionVerificationMiddleware:
"""FastAPI会话验证中间件包装器"""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope, receive)
async def call_next(req: Request) -> Response:
# 这里需要调用FastAPI应用
return Response("OK") # 占位符实现
middleware = SessionVerificationMiddleware(call_next)
response = await middleware(request, call_next)
await response(scope, receive, send)

View File

@@ -10,11 +10,13 @@ from collections.abc import Callable
from typing import ClassVar
from app.auth import get_token_by_access_token
from app.const import SUPPORT_TOTP_VERIFICATION_VER
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from app.utils import extract_user_agent
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
@@ -34,7 +36,9 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/session/verify/mail-fallback",
"/api/v2/me",
"/api/v2/me/",
"/api/v2/logout",
"/oauth/token",
"/health",
@@ -44,10 +48,8 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
"/redoc",
}
# 需要强制验证的路由模式(敏感操作)
# 总是需要验证的路由前缀
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
"/api/v2/account/",
"/api/v2/settings/",
"/api/private/admin/",
}
@@ -110,9 +112,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
if path.startswith(pattern):
return True
# 特权用户或非活跃用户需要验证
# if hasattr(user, 'is_privileged') and user.is_privileged():
# return True
if not user.is_active:
return True
@@ -154,6 +153,14 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
try:
# 提取会话token这里简化为使用相同的auth token
auth_header = request.headers.get("Authorization", "")
api_version = 0
raw_api_version = request.headers.get("x-api-version")
if raw_api_version is not None:
try:
api_version = int(raw_api_version)
except ValueError:
api_version = 0
if not auth_header.startswith("Bearer "):
return None
@@ -168,7 +175,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
if not session or session.user_id != user.id:
return None
return SessionState(session, user, redis, db)
return SessionState(session, user, redis, db, api_version)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
@@ -178,8 +185,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
"""启动验证流程"""
try:
method = await state.get_method()
# 如果是邮件验证,可以在这里触发发送邮件
if method == "mail":
await state.issue_mail_if_needed()
@@ -202,11 +207,12 @@ class SessionState:
简化版本的会话状态管理
"""
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession):
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession, api_version: int = 0) -> None:
self.session = session
self.user = user
self.redis = redis
self.db = db
self.api_version = api_version
self._verification_method: str | None = None
def is_verified(self) -> bool:
@@ -223,14 +229,15 @@ class SessionState:
self.user.id, token_id, self.redis
)
# 如果没有设置,智能选择
if self._verification_method is None:
# 检查用户是否有TOTP密钥
await self.user.awaitable_attrs.totp_key # 预加载
totp_key = getattr(self.user, "totp_key", None)
if self.api_version < SUPPORT_TOTP_VERIFICATION_VER:
self._verification_method = "mail"
return self._verification_method
await self.user.awaitable_attrs.totp_key
totp_key = self.user.totp_key
self._verification_method = "totp" if totp_key else "mail"
# 保存选择的方法
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.set_login_method(
@@ -244,8 +251,15 @@ class SessionState:
try:
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(self.db, self.redis, self.user.id, token_id)
self.session.is_verified = True # 更新本地状态
await LoginSessionService.mark_session_verified(
self.db,
self.redis,
self.user.id,
token_id,
self.session.ip_address,
extract_user_agent(self.session.user_agent),
self.session.web_uuid,
)
except Exception as e:
logger.error(f"[Session State] Error marking verified: {e}")
@@ -266,10 +280,12 @@ class SessionState:
"""获取会话密钥"""
return str(self.session.id) if self.session.id else ""
def get_key_for_event(self) -> str:
@property
def key_for_event(self) -> str:
"""获取用于事件广播的会话密钥"""
return LoginSessionService.get_key_for_event(self.get_key())
@property
def user_id(self) -> int:
"""获取用户ID"""
return self.user.id

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from app.models.score import GameMode
@@ -53,3 +54,33 @@ class CurrentUserAttributes(BaseModel):
can_new_comment: bool | None = None
can_new_comment_reason: str | None = None
pin: PinAttributes | None = None
@dataclass
class UserAgentInfo:
raw_ua: str = ""
browser: str | None = None
version: str | None = None
os: str | None = None
platform: str | None = None
is_mobile: bool = False
is_tablet: bool = False
is_pc: bool = False
is_client: bool = False
@property
def displayed_name(self) -> str:
parts = []
if self.browser:
parts.append(self.browser)
if self.version:
parts.append(self.version)
if self.os:
if parts:
parts.append(f"on {self.os}")
else:
parts.append(self.os)
return " ".join(parts) if parts else "Unknown"
def __str__(self) -> str:
return self.displayed_name

View File

@@ -21,6 +21,7 @@ from app.database.auth import TotpKeys
from app.database.statistics import UserStatistics
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user_agent import UserAgentInfo
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse
@@ -39,7 +40,7 @@ from app.service.verification_service import (
)
from app.utils import utcnow
from fastapi import APIRouter, Depends, Form, Request
from fastapi import APIRouter, Depends, Form, Header, Request
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
@@ -199,6 +200,7 @@ async def register_user(
async def oauth_token(
db: Database,
request: Request,
user_agent: UserAgentInfo,
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
),
@@ -211,12 +213,10 @@ async def oauth_token(
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
):
scopes = scope.split(" ")
# 打印请求头
# logger.info(f"Request headers: {request.headers}")
client = (
await db.exec(
select(OAuthClient).where(
@@ -306,19 +306,19 @@ async def oauth_token(
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
settings.refresh_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
)
token_id = token.id
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 获取国家代码
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
# 检查是否为新位置登录
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
trusted_device = await LoginSessionService.check_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
session_verification_method = None
if settings.enable_totp_verification and totp_key is not None:
@@ -331,18 +331,12 @@ async def oauth_token(
login_method="password_pending_verification",
notes="需要 TOTP 验证",
)
elif is_new_location and settings.enable_email_verification:
# 如果是新位置登录,需要邮件验证
elif not trusted_device and settings.enable_email_verification:
# 如果是新设备登录,需要邮件验证
# 刷新用户对象以确保属性已加载
await db.refresh(user)
session_verification_method = "mail"
# 使用智能验证发送邮件
(
verification_sent,
verification_message,
client_info,
) = await EmailVerificationService.send_smart_verification_email(
await EmailVerificationService.send_verification_email(
db,
redis,
user_id,
@@ -350,36 +344,30 @@ async def oauth_token(
user.email,
ip_address,
user_agent,
client_id,
country_code,
is_new_location,
)
# 记录需要二次验证的登录尝试
client_display_name = client_info.client_type if client_info else "unknown"
await LoginLogService.record_login(
db=db,
user_id=user_id,
request=request,
login_success=True,
login_method="password_pending_verification",
notes=f"智能验证: {verification_message} - 客户端: {client_display_name}, "
f"IP: {ip_address}, 国家: {country_code}",
notes=(
f"邮箱验证: User-Agent: {user_agent.raw_ua}, 客户端: {user_agent.displayed_name} "
f"IP: {ip_address}, 国家: {country_code}"
),
)
elif not trusted_device:
# 新设备登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(
db, redis, user_id, token_id, ip_address, user_agent, web_uuid
)
if not verification_sent:
# 邮件发送失败,记录错误
logger.error(f"[Auth] Smart verification failed for user {user_id}: {verification_message}")
else:
logger.info(f"[Auth] Smart verification result for user {user_id}: {verification_message}")
elif is_new_location:
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
logger.debug(
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
)
else:
# 不是新位置登录,正常登录
# 不是新设备登录,正常登录
await LoginLogService.record_login(
db=db,
user_id=user_id,
@@ -391,12 +379,12 @@ async def oauth_token(
if session_verification_method:
await LoginSessionService.create_session(
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, False
db, user_id, token_id, ip_address, user_agent.raw_ua, trusted_device, web_uuid, False
)
await LoginSessionService.set_login_method(user_id, token_id, session_verification_method, redis)
else:
await LoginSessionService.create_session(
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, True
db, user_id, token_id, ip_address, user_agent.raw_ua, trusted_device, web_uuid, True
)
return TokenResponse(
@@ -449,6 +437,7 @@ async def oauth_token(
access_token,
new_refresh_token,
settings.access_token_expire_minutes * 60,
settings.refresh_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
)
return TokenResponse(
@@ -514,6 +503,7 @@ async def oauth_token(
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
settings.refresh_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
)
@@ -561,6 +551,7 @@ async def oauth_token(
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
settings.refresh_token_expire_minutes * 60,
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from app.config import settings
from . import audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401
from . import admin, audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401
from .router import router as private_router
if settings.enable_totp_verification:

157
app/router/private/admin.py Normal file
View File

@@ -0,0 +1,157 @@
from __future__ import annotations
from app.database.auth import OAuthToken
from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
from app.dependencies.database import Database
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.helpers.geoip_helper import GeoIPHelper
from .router import router
from fastapi import Depends, HTTPException, Security
from pydantic import BaseModel
from sqlmodel import col, select
class SessionsResp(BaseModel):
total: int
current: int = 0
sessions: list[LoginSessionResp]
@router.get(
"/admin/sessions",
name="获取当前用户的登录会话列表",
tags=["用户会话", "g0v0 API", "管理"],
response_model=SessionsResp,
)
async def get_sessions(
session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token),
geoip: GeoIPHelper = Depends(get_geoip_helper),
):
current_user, token = user_and_token
sessions = (
await session.exec(
select(
LoginSession,
)
.where(LoginSession.user_id == current_user.id, col(LoginSession.is_verified).is_(True))
.order_by(col(LoginSession.created_at).desc())
)
).all()
return SessionsResp(
total=len(sessions),
current=token.id,
sessions=[LoginSessionResp.from_db(s, geoip) for s in sessions],
)
@router.delete(
"/admin/sessions/{session_id}",
name="注销指定的登录会话",
tags=["用户会话", "g0v0 API", "管理"],
status_code=204,
)
async def delete_session(
session: Database,
session_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token),
):
current_user, token = user_and_token
if session_id == token.id:
raise HTTPException(status_code=400, detail="Cannot delete the current session")
db_session = await session.get(LoginSession, session_id)
if not db_session or db_session.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Session not found")
await session.delete(db_session)
token = await session.get(OAuthToken, db_session.token_id or 0)
if token:
await session.delete(token)
await session.commit()
return
class TrustedDevicesResp(BaseModel):
total: int
current: int = 0
devices: list[TrustedDeviceResp]
@router.get(
"/admin/trusted-devices",
name="获取当前用户的受信任设备列表",
tags=["用户会话", "g0v0 API", "管理"],
response_model=TrustedDevicesResp,
)
async def get_trusted_devices(
session: Database,
user_and_token: UserAndToken = Security(get_client_user_and_token),
geoip: GeoIPHelper = Depends(get_geoip_helper),
):
current_user, token = user_and_token
devices = (
await session.exec(
select(TrustedDevice)
.where(TrustedDevice.user_id == current_user.id)
.order_by(col(TrustedDevice.last_used_at).desc())
)
).all()
current_device_id = (
await session.exec(
select(TrustedDevice.id)
.join(LoginSession, col(LoginSession.device_id) == TrustedDevice.id)
.where(
LoginSession.token_id == token.id,
TrustedDevice.user_id == current_user.id,
)
.limit(1)
)
).first()
return TrustedDevicesResp(
total=len(devices),
current=current_device_id or 0,
devices=[TrustedDeviceResp.from_db(device, geoip) for device in devices],
)
@router.delete(
"/admin/trusted-devices/{device_id}",
name="移除受信任设备",
tags=["用户会话", "g0v0 API", "管理"],
status_code=204,
)
async def delete_trusted_device(
session: Database,
device_id: int,
user_and_token: UserAndToken = Security(get_client_user_and_token),
):
current_user, token = user_and_token
device = await session.get(TrustedDevice, device_id)
current_device_id = (
await session.exec(
select(TrustedDevice.id)
.join(LoginSession, col(LoginSession.device_id) == TrustedDevice.id)
.where(
LoginSession.token_id == token.id,
TrustedDevice.user_id == current_user.id,
)
.limit(1)
)
).first()
if device_id == current_device_id:
raise HTTPException(status_code=400, detail="Cannot delete the current trusted device")
if not device or device.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Trusted device not found")
await session.delete(device)
await session.commit()
return

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from app.database import MeResp, User
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user_and_token
@@ -33,7 +32,7 @@ async def get_user_info_with_ruleset(
ruleset: GameMode = Path(description="指定 ruleset"),
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
):
user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, ruleset, token_id=user_and_token[1].id)
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
return user_resp
@@ -48,7 +47,7 @@ async def get_user_info_default(
session: Database,
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
):
user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, None, token_id=user_and_token[1].id)
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
return user_resp

View File

@@ -8,12 +8,13 @@ from typing import Annotated, Literal
from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection
from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
from app.database.auth import TotpKeys
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip
from app.dependencies.user import UserAndToken, get_client_user_and_token
from app.dependencies.user_agent import UserAgentInfo
from app.log import logger
from app.service.login_log_service import LoginLogService
from app.service.verification_service import (
@@ -23,7 +24,7 @@ from app.service.verification_service import (
from .router import router
from fastapi import Depends, Form, HTTPException, Request, Security, status
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
@@ -62,9 +63,11 @@ async def verify_session(
request: Request,
db: Database,
api_version: APIVersion,
user_agent: UserAgentInfo,
redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 g0v0 扩展支持)"),
user_and_token: UserAndToken = Security(get_client_user_and_token),
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
) -> Response:
current_user = user_and_token[0]
token_id = user_and_token[1].id
@@ -74,11 +77,12 @@ async def verify_session(
return Response(status_code=status.HTTP_204_NO_CONTENT)
verify_method: str | None = (
"mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis)
"mail"
if api_version < SUPPORT_TOTP_VERIFICATION_VER
else await LoginSessionService.get_login_method(user_id, token_id, redis)
)
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
login_method = "password"
try:
@@ -130,10 +134,11 @@ async def verify_session(
user_id=user_id,
request=request,
login_method=login_method,
user_agent=user_agent.raw_ua,
login_success=True,
notes=f"{login_method} 验证成功",
)
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id, ip_address, user_agent, web_uuid)
await db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
@@ -179,6 +184,7 @@ async def verify_session(
async def reissue_verification_code(
request: Request,
db: Database,
user_agent: UserAgentInfo,
api_version: APIVersion,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
@@ -198,7 +204,6 @@ async def reissue_verification_code(
try:
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
user_id = current_user.id
success, message = await EmailVerificationService.resend_verification_code(
db,
@@ -227,6 +232,7 @@ async def reissue_verification_code(
)
async def fallback_email(
db: Database,
user_agent: UserAgentInfo,
request: Request,
redis: Annotated[Redis, Depends(get_redis)],
user_and_token: UserAndToken = Security(get_client_user_and_token),
@@ -237,7 +243,6 @@ async def fallback_email(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
success, message = await EmailVerificationService.resend_verification_code(

View File

@@ -1,122 +0,0 @@
"""
数据库清理调度器 - 定时清理过期数据
"""
from __future__ import annotations
import asyncio
from app.dependencies.database import engine
from app.log import logger
from app.service.database_cleanup_service import DatabaseCleanupService
from sqlmodel.ext.asyncio.session import AsyncSession
class DatabaseCleanupScheduler:
"""数据库清理调度器"""
def __init__(self):
self.running = False
self.task = None
async def start(self):
"""启动调度器"""
if self.running:
return
self.running = True
self.task = asyncio.create_task(self._run_scheduler())
logger.debug("Database cleanup scheduler started")
async def stop(self):
"""停止调度器"""
if not self.running:
return
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
logger.debug("Database cleanup scheduler stopped")
async def _run_scheduler(self):
"""运行调度器"""
while self.running:
try:
# 每小时运行一次清理
await asyncio.sleep(3600) # 3600秒 = 1小时
if not self.running:
break
await self._run_cleanup()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Database cleanup scheduler error: {e!s}")
# 发生错误后等待5分钟再继续
await asyncio.sleep(300)
async def _run_cleanup(self):
"""执行清理任务"""
try:
async with AsyncSession(engine) as db:
logger.debug("Starting scheduled database cleanup...")
# 清理过期的验证码
expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# 清理过期的登录会话
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理1小时前未验证的登录会话
unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
# 只在有清理记录时输出总结
total_cleaned = expired_codes + expired_sessions + unverified_sessions
if total_cleaned > 0:
logger.debug(
f"Scheduled cleanup completed - codes: {expired_codes}, "
f"sessions: {expired_sessions}, unverified: {unverified_sessions}"
)
except Exception as e:
logger.error(f"Error during scheduled database cleanup: {e!s}")
async def run_manual_cleanup(self):
"""手动运行完整清理"""
try:
async with AsyncSession(engine) as db:
logger.debug("Starting manual database cleanup...")
results = await DatabaseCleanupService.run_full_cleanup(db)
total = sum(results.values())
if total > 0:
logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
return results
except Exception as e:
logger.error(f"Error during manual database cleanup: {e!s}")
return {}
# 全局实例
database_cleanup_scheduler = DatabaseCleanupScheduler()
async def start_database_cleanup_scheduler():
"""启动数据库清理调度器"""
await database_cleanup_scheduler.start()
async def stop_database_cleanup_scheduler():
"""停止数据库清理调度器"""
await database_cleanup_scheduler.stop()
async def run_manual_database_cleanup():
"""手动运行数据库清理"""
return await database_cleanup_scheduler.run_manual_cleanup()

View File

@@ -1,230 +0,0 @@
"""
客户端检测服务
用于识别不同类型的 osu! 客户端和设备
"""
from __future__ import annotations
from dataclasses import dataclass
import hashlib
import re
from typing import ClassVar, Literal
from app.log import logger
@dataclass
class ClientInfo:
"""客户端信息"""
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"]
platform: str | None = None
version: str | None = None
device_fingerprint: str | None = None
is_trusted_client: bool = False
class ClientDetectionService:
"""客户端检测服务"""
# osu! 客户端的 User-Agent 模式
OSU_CLIENT_PATTERNS: ClassVar[dict[str, list[str]]] = {
"osu_stable": [
r"osu!/(\d+(?:\.\d+)*)", # osu!/20241001
r"osu!", # 简单匹配
],
"osu_lazer": [
r"osu-lazer/(\d+(?:\.\d+)*)", # osu-lazer/2024.1009.0
r"osu!lazer/(\d+(?:\.\d+)*)", # osu!lazer/2024.1009.0
],
"osu_web": [
r"Mozilla.*osu\.ppy\.sh", # 网页客户端
],
"mobile": [
r"osu!.*mobile",
r"osu.*Mobile",
r"Mobile.*osu",
],
}
# 受信任的客户端类型(不需要频繁验证)
TRUSTED_CLIENT_TYPES: ClassVar[set[str]] = {"osu_stable", "osu_lazer"}
@staticmethod
def detect_client(user_agent: str | None, client_id: int | None = None) -> ClientInfo:
"""
检测客户端类型和信息
Args:
user_agent: 用户代理字符串
client_id: OAuth 客户端 ID
Returns:
ClientInfo: 客户端信息
"""
from app.config import settings # 导入在函数内部避免循环导入
if not user_agent:
return ClientInfo(client_type="unknown")
# 优先通过 client_id 判断客户端类型
if client_id is not None:
if client_id == settings.osu_client_id:
# osu! stable 客户端
return ClientInfo(
client_type="osu_stable",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=True,
)
elif client_id == settings.osu_web_client_id:
# 检查 User-Agent 是否表明这是 Lazer 客户端
if user_agent and user_agent.strip() == "osu!":
# Lazer 客户端使用 web client_id 但发送简单的 "osu!" User-Agent
return ClientInfo(
client_type="osu_lazer",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=True,
)
else:
# 真正的 web 客户端
return ClientInfo(
client_type="osu_web",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
# 回退到基于 User-Agent 的检测
for client_type_str, patterns in ClientDetectionService.OSU_CLIENT_PATTERNS.items():
for pattern in patterns:
match = re.search(pattern, user_agent, re.IGNORECASE)
if match:
version = match.group(1) if match.groups() else None
platform = ClientDetectionService._extract_platform(user_agent)
# 确保 client_type 是正确的 Literal 类型
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] = client_type_str # type: ignore
return ClientInfo(
client_type=client_type,
platform=platform,
version=version,
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=client_type in ClientDetectionService.TRUSTED_CLIENT_TYPES,
)
# 检测常见浏览器
if any(browser in user_agent.lower() for browser in ["chrome", "firefox", "safari", "edge"]):
return ClientInfo(
client_type="osu_web",
platform=ClientDetectionService._extract_platform(user_agent),
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
return ClientInfo(
client_type="unknown",
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
is_trusted_client=False,
)
@staticmethod
def _extract_platform(user_agent: str) -> str | None:
"""从 User-Agent 中提取平台信息"""
platforms = {
"windows": ["windows", "win32", "win64"],
"macos": ["macintosh", "mac os", "darwin"],
"linux": ["linux", "ubuntu", "debian"],
"android": ["android"],
"ios": ["iphone", "ipad", "ios"],
}
user_agent_lower = user_agent.lower()
for platform, keywords in platforms.items():
if any(keyword in user_agent_lower for keyword in keywords):
return platform
return None
@staticmethod
def _generate_device_fingerprint(user_agent: str) -> str:
"""生成设备指纹"""
# 使用 User-Agent 的哈希值作为简单的设备指纹
# 在实际应用中可以结合更多信息IP、屏幕分辨率等
return hashlib.sha256(user_agent.encode()).hexdigest()[:16]
@staticmethod
def should_skip_email_verification(
client_info: ClientInfo,
is_new_location: bool,
user_id: int,
) -> bool:
"""
判断是否应该跳过邮件验证
Args:
client_info: 客户端信息
is_new_location: 是否为新位置登录
user_id: 用户 ID
Returns:
bool: 是否应该跳过邮件验证
"""
# 受信任的客户端类型可以减少验证频率
if client_info.is_trusted_client:
logger.info(
f"[Client Detection] Trusted client {client_info.client_type} for user {user_id}, "
f"reducing verification requirements"
)
return True
# 如果不是新位置,跳过验证
if not is_new_location:
return True
return False
@staticmethod
def get_verification_cooldown(client_info: ClientInfo) -> int:
"""
获取验证冷却时间(秒)
Args:
client_info: 客户端信息
Returns:
int: 冷却时间(秒)
"""
# 受信任的客户端有更长的冷却时间
if client_info.is_trusted_client:
return 3600 # 1小时
# 网页客户端较短的冷却时间
if client_info.client_type == "osu_web":
return 1800 # 30分钟
# 未知客户端最短冷却时间
return 900 # 15分钟
@staticmethod
def format_client_display_name(client_info: ClientInfo) -> str:
"""格式化客户端显示名称"""
display_names = {
"osu_stable": "osu! (stable)",
"osu_lazer": "osu!(lazer)",
"osu_web": "osu! web",
"mobile": "osu! mobile",
"unknown": "Unknown client",
}
base_name = display_names.get(client_info.client_type, "Unknown client")
if client_info.version:
base_name += f" v{client_info.version}"
if client_info.platform:
base_name += f" ({client_info.platform})"
return base_name

View File

@@ -6,11 +6,14 @@ from __future__ import annotations
from datetime import timedelta
from app.database.verification import EmailVerification, LoginSession
from app.database.auth import OAuthToken
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
from app.dependencies.database import with_db
from app.dependencies.scheduler import get_scheduler
from app.log import logger
from app.utils import utcnow
from sqlmodel import col, select
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -69,7 +72,9 @@ class DatabaseCleanupService:
# 查找过期的登录会话记录
current_time = utcnow()
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
stmt = select(LoginSession).where(
LoginSession.expires_at < current_time, col(LoginSession.is_verified).is_(False)
)
result = await db.exec(stmt)
expired_sessions = result.all()
@@ -179,50 +184,109 @@ class DatabaseCleanupService:
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
async def cleanup_outdated_verified_sessions(db: AsyncSession) -> int:
"""
清理旧的已验证会话记录
清理过期会话记录
Args:
db: 数据库会话
days_old: 清理多少天前的已验证记录默认30天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已验证会话记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
stmt = select(LoginSession).where(
col(LoginSession.is_verified).is_(True), col(LoginSession.token_id).is_(None)
)
result = await db.exec(stmt)
all_verified_sessions = result.all()
# 筛选出过期的记录
old_verified_sessions = [
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time
]
# 删除旧的已验证记录
deleted_count = 0
for session in old_verified_sessions:
for session in result.all():
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
)
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} outdated verified sessions")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
logger.error(f"[Cleanup Service] Error cleaning outdated verified sessions: {e!s}")
return 0
@staticmethod
async def cleanup_outdated_trusted_devices(db: AsyncSession) -> int:
"""
清理过期的受信任设备记录
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的受信任设备记录
current_time = utcnow()
stmt = select(TrustedDevice).where(TrustedDevice.expires_at < current_time)
result = await db.exec(stmt)
expired_devices = result.all()
# 删除过期的记录
deleted_count = 0
for device in expired_devices:
await db.delete(device)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired trusted devices")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired trusted devices: {e!s}")
return 0
@staticmethod
async def cleanup_outdated_tokens(db: AsyncSession) -> int:
"""
清理过期的 OAuth 令牌
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
current_time = utcnow()
stmt = select(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
result = await db.exec(stmt)
expired_tokens = result.all()
deleted_count = 0
for token in expired_tokens:
await db.delete(token)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired OAuth tokens")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired OAuth tokens: {e!s}")
return 0
@staticmethod
@@ -250,8 +314,14 @@ class DatabaseCleanupService:
# 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# 清理30天前的已验证会话
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
# 清理过期的受信任设备
results["outdated_trusted_devices"] = await DatabaseCleanupService.cleanup_outdated_trusted_devices(db)
# 清理过期的 OAuth 令牌
results["outdated_oauth_tokens"] = await DatabaseCleanupService.cleanup_outdated_tokens(db)
# 清理过期token 过期)的已验证会话
results["outdated_verified_sessions"] = await DatabaseCleanupService.cleanup_outdated_verified_sessions(db)
total_cleaned = sum(results.values())
if total_cleaned > 0:
@@ -279,21 +349,27 @@ class DatabaseCleanupService:
cutoff_30_days = current_time - timedelta(days=30)
# 统计过期的验证码数量
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
expired_codes_stmt = (
select(func.count()).select_from(EmailVerification).where(EmailVerification.expires_at < current_time)
)
expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all())
expired_codes_count = expired_codes_result.one()
# 统计过期的登录会话数量
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
expired_sessions_stmt = (
select(func.count()).select_from(LoginSession).where(LoginSession.expires_at < current_time)
)
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
expired_sessions_count = expired_sessions_result.one()
# 统计1小时前未验证的登录会话数量
unverified_sessions_stmt = select(LoginSession).where(
col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour
unverified_sessions_stmt = (
select(func.count())
.select_from(LoginSession)
.where(col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour)
)
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
unverified_sessions_count = len(unverified_sessions_result.all())
unverified_sessions_count = unverified_sessions_result.one()
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
@@ -304,10 +380,10 @@ class DatabaseCleanupService:
)
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len(
outdated_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
outdated_verified_sessions_result = await db.exec(outdated_verified_sessions_stmt)
all_verified_sessions = outdated_verified_sessions_result.all()
outdated_verified_sessions_count = len(
[
session
for session in all_verified_sessions
@@ -315,17 +391,35 @@ class DatabaseCleanupService:
]
)
# 统计过期的 OAuth 令牌数量
outdated_tokens_stmt = (
select(func.count()).select_from(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
)
outdated_tokens_result = await db.exec(outdated_tokens_stmt)
outdated_tokens_count = outdated_tokens_result.one()
# 统计过期的受信任设备数量
outdated_devices_stmt = (
select(func.count()).select_from(TrustedDevice).where(TrustedDevice.expires_at < current_time)
)
outdated_devices_result = await db.exec(outdated_devices_stmt)
outdated_devices_count = outdated_devices_result.one()
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"unverified_login_sessions": unverified_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"outdated_verified_sessions": outdated_verified_sessions_count,
"outdated_oauth_tokens": outdated_tokens_count,
"outdated_trusted_devices": outdated_devices_count,
"total_cleanable": expired_codes_count
+ expired_sessions_count
+ unverified_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
+ outdated_verified_sessions_count
+ outdated_tokens_count
+ outdated_devices_count,
}
except Exception as e:
@@ -335,6 +429,23 @@ class DatabaseCleanupService:
"expired_login_sessions": 0,
"unverified_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"outdated_verified_sessions": 0,
"outdated_oauth_tokens": 0,
"outdated_trusted_devices": 0,
"total_cleanable": 0,
}
@get_scheduler().scheduled_job(
"interval",
id="cleanup_database",
hours=1,
)
async def scheduled_cleanup_job():
async with with_db() as session:
logger.debug("Starting database cleanup...")
results = await DatabaseCleanupService.run_full_cleanup(session)
total = sum(results.values())
if total > 0:
logger.debug(f"Cleanup completed, total records cleaned: {total}")
return results

View File

@@ -1,283 +0,0 @@
"""
设备信任服务
管理用户的受信任设备,减少频繁验证
"""
from __future__ import annotations
from datetime import timedelta
from app.config import settings
from app.log import logger
from app.service.client_detection_service import ClientInfo
from app.utils import utcnow
from redis.asyncio import Redis
class DeviceTrustService:
"""设备信任服务"""
@staticmethod
def _get_device_trust_key(user_id: int, device_fingerprint: str) -> str:
"""获取设备信任的 Redis 键"""
return f"device_trust:{user_id}:{device_fingerprint}"
@staticmethod
def _get_location_trust_key(user_id: int, country_code: str) -> str:
"""获取位置信任的 Redis 键"""
return f"location_trust:{user_id}:{country_code}"
@staticmethod
def _get_verification_cooldown_key(user_id: int) -> str:
"""获取验证冷却的 Redis 键"""
return f"verification_cooldown:{user_id}"
@staticmethod
async def is_device_trusted(
redis: Redis,
user_id: int,
device_fingerprint: str,
) -> bool:
"""
检查设备是否受信任
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
Returns:
bool: 设备是否受信任
"""
if not device_fingerprint:
return False
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
trust_data = await redis.get(trust_key)
return trust_data is not None
@staticmethod
async def is_location_trusted(
redis: Redis,
user_id: int,
country_code: str | None,
) -> bool:
"""
检查位置是否受信任
Args:
redis: Redis 连接
user_id: 用户 ID
country_code: 国家代码
Returns:
bool: 位置是否受信任
"""
if not country_code:
return False
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
trust_data = await redis.get(trust_key)
return trust_data is not None
@staticmethod
async def is_in_verification_cooldown(
redis: Redis,
user_id: int,
) -> bool:
"""
检查用户是否在验证冷却期内
Args:
redis: Redis 连接
user_id: 用户 ID
Returns:
bool: 是否在冷却期内
"""
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
cooldown_data = await redis.get(cooldown_key)
return cooldown_data is not None
@staticmethod
async def trust_device(
redis: Redis,
user_id: int,
device_fingerprint: str,
client_info: ClientInfo,
trust_duration_days: int | None = None,
) -> None:
"""
信任设备
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
client_info: 客户端信息
trust_duration_days: 信任持续天数
"""
if not device_fingerprint:
return
# 使用配置中的默认值
if trust_duration_days is None:
trust_duration_days = settings.device_trust_duration_days
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
trust_data = {
"client_type": client_info.client_type,
"platform": client_info.platform or "unknown",
"trusted_at": utcnow().isoformat(),
}
# 设置信任期限
trust_duration_seconds = trust_duration_days * 24 * 3600
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
logger.info(
f"[Device Trust] Device trusted for user {user_id}: "
f"{client_info.client_type} on {client_info.platform} "
f"(fingerprint: {device_fingerprint[:8]}...)"
)
@staticmethod
async def trust_location(
redis: Redis,
user_id: int,
country_code: str,
trust_duration_days: int | None = None,
) -> None:
"""
信任位置
Args:
redis: Redis 连接
user_id: 用户 ID
country_code: 国家代码
trust_duration_days: 信任持续天数
"""
if not country_code:
return
# 使用配置中的默认值
if trust_duration_days is None:
trust_duration_days = settings.location_trust_duration_days
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
trust_data = {
"country_code": country_code,
"trusted_at": utcnow().isoformat(),
}
# 设置信任期限
trust_duration_seconds = trust_duration_days * 24 * 3600
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
logger.info(f"[Location Trust] Location trusted for user {user_id}: {country_code}")
@staticmethod
async def set_verification_cooldown(
redis: Redis,
user_id: int,
cooldown_seconds: int,
) -> None:
"""
设置验证冷却期
Args:
redis: Redis 连接
user_id: 用户 ID
cooldown_seconds: 冷却时间(秒)
"""
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
cooldown_data = {
"set_at": utcnow().isoformat(),
"expires_at": (utcnow() + timedelta(seconds=cooldown_seconds)).isoformat(),
}
await redis.setex(cooldown_key, cooldown_seconds, str(cooldown_data))
logger.info(f"[Verification Cooldown] Set cooldown for user {user_id}: {cooldown_seconds}s")
@staticmethod
async def should_require_verification(
redis: Redis,
user_id: int,
device_fingerprint: str | None,
country_code: str | None,
client_info: ClientInfo,
is_new_location: bool,
) -> tuple[bool, str]:
"""
判断是否需要验证
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
country_code: 国家代码
client_info: 客户端信息
is_new_location: 是否为新位置
Returns:
tuple[bool, str]: (是否需要验证, 原因)
"""
# 检查验证冷却期
if await DeviceTrustService.is_in_verification_cooldown(redis, user_id):
return False, "用户在验证冷却期内"
# 检查设备信任
if device_fingerprint and await DeviceTrustService.is_device_trusted(redis, user_id, device_fingerprint):
return False, "设备已受信任"
# 检查位置信任
if country_code and await DeviceTrustService.is_location_trusted(redis, user_id, country_code):
return False, "位置已受信任"
# 受信任的客户端类型降低验证要求
if client_info.is_trusted_client and not is_new_location:
return False, "受信任客户端且非新位置"
# 如果是新位置登录,需要验证
if is_new_location:
return True, "新位置登录需要验证"
# 默认不需要验证
return False, "常规登录无需验证"
@staticmethod
async def mark_verification_successful(
redis: Redis,
user_id: int,
device_fingerprint: str | None,
country_code: str | None,
client_info: ClientInfo,
) -> None:
"""
标记验证成功,更新信任信息
Args:
redis: Redis 连接
user_id: 用户 ID
device_fingerprint: 设备指纹
country_code: 国家代码
client_info: 客户端信息
"""
# 信任设备
if device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, device_fingerprint, client_info)
# 信任位置
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
# 设置验证冷却期
cooldown_seconds = (client_info.is_trusted_client and 3600) or 1800 # 受信任客户端1小时其他30分钟
await DeviceTrustService.set_verification_cooldown(redis, user_id, cooldown_seconds)
logger.info(f"[Device Trust] Verification successful for user {user_id}, trust updated")

View File

@@ -9,7 +9,7 @@ import asyncio
from app.database.user_login_log import UserLoginLog
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
from app.log import logger
from app.utils import simplify_user_agent, utcnow
from app.utils import utcnow
from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -23,6 +23,7 @@ class LoginLogService:
db: AsyncSession,
user_id: int,
request: Request,
user_agent: str | None = None,
login_success: bool = True,
login_method: str = "password",
notes: str | None = None,
@@ -45,9 +46,6 @@ class LoginLogService:
raw_ip = get_client_ip(request)
ip_address = normalize_ip(raw_ip)
raw_user_agent = request.headers.get("User-Agent", "")
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
# 创建基本的登录记录
login_log = UserLoginLog(
user_id=user_id,
@@ -107,6 +105,7 @@ class LoginLogService:
attempted_username: str | None = None,
login_method: str = "password",
notes: str | None = None,
user_agent: str | None = None,
) -> UserLoginLog:
"""
记录失败的登录尝试
@@ -128,6 +127,7 @@ class LoginLogService:
request=request,
login_success=False,
login_method=login_method,
user_agent=user_agent,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
)

View File

@@ -1,122 +0,0 @@
"""
API 状态管理 - 模拟 osu! 的 APIState 和会话管理
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
class APIState(str, Enum):
"""API 连接状态,对应 osu! 的 APIState"""
OFFLINE = "offline"
CONNECTING = "connecting"
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
ONLINE = "online"
FAILING = "failing"
class UserSession(BaseModel):
"""用户会话信息"""
user_id: int
username: str
email: str
session_token: str | None = None
state: APIState = APIState.OFFLINE
requires_verification: bool = False
verification_sent: bool = False
last_verification_attempt: datetime | None = None
failed_attempts: int = 0
ip_address: str | None = None
country_code: str | None = None
is_new_location: bool = False
class SessionManager:
"""会话管理器"""
def __init__(self):
self._sessions: dict[str, UserSession] = {}
def create_session(
self,
user_id: int,
username: str,
email: str,
ip_address: str,
country_code: str | None = None,
is_new_location: bool = False,
) -> UserSession:
"""创建新的用户会话"""
import secrets
session_token = secrets.token_urlsafe(32)
# 根据是否为新位置决定初始状态
if is_new_location:
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
else:
state = APIState.ONLINE
session = UserSession(
user_id=user_id,
username=username,
email=email,
session_token=session_token,
state=state,
requires_verification=is_new_location,
ip_address=ip_address,
country_code=country_code,
is_new_location=is_new_location,
)
self._sessions[session_token] = session
return session
def get_session(self, session_token: str) -> UserSession | None:
"""获取会话"""
return self._sessions.get(session_token)
def update_session_state(self, session_token: str, state: APIState):
"""更新会话状态"""
if session_token in self._sessions:
self._sessions[session_token].state = state
def mark_verification_sent(self, session_token: str):
"""标记验证邮件已发送"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.verification_sent = True
session.last_verification_attempt = datetime.now()
def increment_failed_attempts(self, session_token: str):
"""增加失败尝试次数"""
if session_token in self._sessions:
self._sessions[session_token].failed_attempts += 1
def verify_session(self, session_token: str) -> bool:
"""验证会话成功"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.state = APIState.ONLINE
session.requires_verification = False
return True
return False
def remove_session(self, session_token: str):
"""移除会话"""
self._sessions.pop(session_token, None)
def cleanup_expired_sessions(self):
"""清理过期会话"""
# 这里可以实现清理逻辑
pass
# 全局会话管理器
session_manager = SessionManager()

View File

@@ -10,15 +10,15 @@ import string
from typing import Literal
from app.config import settings
from app.database.verification import EmailVerification, LoginSession
from app.database.auth import OAuthToken
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
from app.log import logger
from app.service.client_detection_service import ClientDetectionService, ClientInfo
from app.service.device_trust_service import DeviceTrustService
from app.service.email_queue import email_queue # 导入邮件队列
from app.models.model import UserAgentInfo
from app.service.email_queue import email_queue
from app.utils import utcnow
from redis.asyncio import Redis
from sqlmodel import exists, select
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -248,11 +248,9 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
user_agent: UserAgentInfo | None = None,
) -> bool:
"""发送验证邮件(带智能检测)"""
"""发送验证邮件"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
@@ -260,32 +258,14 @@ This email was sent automatically, please do not reply.
return True # 返回成功,但不执行验证流程
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
logger.info(
f"[Email Verification] Detected client for user {user_id}: "
f"{ClientDetectionService.format_client_display_name(client_info)}"
)
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=True, # 这里需要从调用方传入
)
if not needs_verification:
logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}")
return True
logger.info(f"[Email Verification] Detected client for user {user_id}: {user_agent}")
# 创建验证记录
(
_,
code,
) = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
db, redis, user_id, email, ip_address, user_agent.raw_ua if user_agent else None
)
# 使用邮件队列发送验证邮件
@@ -304,107 +284,6 @@ This email was sent automatically, please do not reply.
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False
@staticmethod
async def send_smart_verification_email(
db: AsyncSession,
redis: Redis,
user_id: int,
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
is_new_location: bool = False,
) -> tuple[bool, str, ClientInfo | None]:
"""
智能邮件验证发送
Args:
db: 数据库会话
redis: Redis 连接
user_id: 用户 ID
username: 用户名
email: 邮箱地址
ip_address: IP 地址
user_agent: 用户代理
client_id: 客户端 ID
country_code: 国家代码
is_new_location: 是否为新位置登录
Returns:
tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息)
"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}")
return True, "邮件验证功能已禁用", None
# 检查是否启用智能验证
if not settings.enable_smart_verification:
logger.debug(
f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}"
)
# 回退到传统验证逻辑
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
success = await EmailVerificationService.send_verification_email_via_queue(
email, code, username, user_id
)
return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None
# 检测客户端信息
client_info = ClientDetectionService.detect_client(user_agent, client_id)
client_display_name = ClientDetectionService.format_client_display_name(client_info)
logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}")
# 检查是否需要验证
needs_verification, reason = await DeviceTrustService.should_require_verification(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
is_new_location=is_new_location,
)
if not needs_verification:
logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}")
# 即使不需要验证,也要更新设备信任信息
if client_info.device_fingerprint:
await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info)
if country_code:
await DeviceTrustService.trust_location(redis, user_id, country_code)
return True, f"跳过验证: {reason}", client_info
# 创建验证记录
verification, code = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
_ = verification # 避免未使用变量警告
# 使用邮件队列发送验证邮件
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success:
logger.info(
f"[Smart Verification] Successfully sent verification email to {email} "
f"for user {username} using {client_display_name}"
)
return True, "验证邮件已发送", client_info
else:
logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})")
return False, "验证邮件发送失败", client_info
except Exception as e:
logger.error(f"[Smart Verification] Exception during smart verification: {e}")
return False, f"验证过程中发生错误: {e!s}", None
@staticmethod
async def verify_email_code(
db: AsyncSession,
@@ -416,7 +295,7 @@ This email was sent automatically, please do not reply.
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]:
"""验证邮箱验证码(带智能信任更新)"""
"""验证邮箱验证码"""
try:
# 检查是否启用邮件验证功能
if not settings.enable_email_verification:
@@ -452,16 +331,6 @@ This email was sent automatically, please do not reply.
# 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}")
# 检测客户端信息并更新信任状态
client_info = ClientDetectionService.detect_client(user_agent, client_id)
await DeviceTrustService.mark_verification_successful(
redis=redis,
user_id=user_id,
device_fingerprint=client_info.device_fingerprint,
country_code=country_code,
client_info=client_info,
)
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功"
@@ -477,7 +346,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None,
user_agent: UserAgentInfo | None = None,
) -> tuple[bool, str]:
"""重新发送验证码"""
try:
@@ -516,12 +385,12 @@ class LoginSessionService:
# Session verification interface methods
@staticmethod
async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None:
async def find_for_verification(db: AsyncSession, token: str) -> LoginSession | None:
"""根据会话ID查找会话用于验证"""
try:
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_id,
col(LoginSession.token).has(col(OAuthToken.access_token) == token),
LoginSession.expires_at > utcnow(),
)
)
@@ -537,42 +406,31 @@ class LoginSessionService:
@staticmethod
async def create_session(
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False,
is_new_device: bool = False,
web_uuid: str | None = None,
is_verified: bool = False,
) -> LoginSession:
"""创建登录会话"""
session_token = EmailVerificationService.generate_session_token()
session = LoginSession(
user_id=user_id,
token_id=token_id,
ip_address=ip_address,
user_agent=None,
country_code=country_code,
is_new_location=is_new_location,
user_agent=user_agent,
is_new_device=is_new_device,
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
is_verified=is_verified,
web_uuid=web_uuid,
)
db.add(session)
await db.commit()
await db.refresh(session)
# 存储到 Redis
await redis.setex(
f"login_session:{session_token}",
86400, # 24小时
user_id,
)
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
logger.info(f"[Login Session] Created session for user {user_id} (new device: {is_new_device})")
return session
@classmethod
@@ -592,35 +450,98 @@ class LoginSessionService:
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
@staticmethod
async def check_new_location(
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
async def check_trusted_device(
db: AsyncSession, user_id: int, ip_address: str, user_agent: UserAgentInfo, web_uuid: str | None = None
) -> bool:
"""检查是否为新位置登录"""
try:
# 查看过去30天内是否有相同IP或相同国家的登录记录
thirty_days_ago = utcnow() - timedelta(days=30)
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
)
if user_agent.is_client:
query = select(exists()).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "client",
TrustedDevice.ip_address == ip_address,
TrustedDevice.expires_at > utcnow(),
)
existing_sessions = result.all()
# 如果有历史记录,则不是新位置
return len(existing_sessions) == 0
except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}")
# 出错时默认为新位置(更安全)
return True
else:
if web_uuid is None:
return False
query = select(exists()).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "web",
TrustedDevice.web_uuid == web_uuid,
TrustedDevice.expires_at > utcnow(),
)
return (await db.exec(query)).first() or False
@staticmethod
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
async def create_trusted_device(
db: AsyncSession,
user_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> TrustedDevice:
device = TrustedDevice(
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent.raw_ua,
client_type="client" if user_agent.is_client else "web",
web_uuid=web_uuid if not user_agent.is_client else None,
expires_at=utcnow() + timedelta(days=settings.device_trust_duration_days),
)
db.add(device)
await db.commit()
await db.refresh(device)
return device
@staticmethod
async def get_or_create_trusted_device(
db: AsyncSession,
user_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> TrustedDevice:
if user_agent.is_client:
query = select(TrustedDevice).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "client",
TrustedDevice.ip_address == ip_address,
)
else:
if web_uuid is None:
raise ValueError("web_uuid is required for web clients")
query = select(TrustedDevice).where(
TrustedDevice.user_id == user_id,
TrustedDevice.client_type == "web",
TrustedDevice.web_uuid == web_uuid,
)
device = (await db.exec(query)).first()
if device is None:
device = await LoginSessionService.create_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
else:
device.last_used_at = utcnow()
device.expires_at = utcnow() + timedelta(days=settings.device_trust_duration_days)
await db.commit()
await db.refresh(device)
return device
@staticmethod
async def mark_session_verified(
db: AsyncSession,
redis: Redis,
user_id: int,
token_id: int,
ip_address: str,
user_agent: UserAgentInfo,
web_uuid: str | None = None,
) -> bool:
"""标记用户的未验证会话为已验证"""
device_info: TrustedDevice | None = None
if user_agent.is_client or web_uuid:
device_info = await LoginSessionService.get_or_create_trusted_device(
db, user_id, ip_address, user_agent, web_uuid
)
try:
# 查找用户所有未验证且未过期的会话
result = await db.exec(
@@ -631,18 +552,20 @@ class LoginSessionService:
LoginSession.token_id == token_id,
)
)
sessions = result.all()
# 标记所有会话为已验证
for session in sessions:
session.is_verified = True
session.verified_at = utcnow()
if device_info:
session.device_id = device_info.id
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
await LoginSessionService.clear_login_method(user_id, token_id, redis)
await db.commit()
return len(sessions) > 0
@@ -658,7 +581,7 @@ class LoginSessionService:
await db.exec(
select(exists()).where(
LoginSession.user_id == user_id,
LoginSession.is_verified == False, # noqa: E712
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > utcnow(),
LoginSession.token_id == token_id,
)

View File

@@ -6,11 +6,15 @@ from datetime import UTC, datetime
import functools
import inspect
from io import BytesIO
from typing import Any, ParamSpec, TypeVar
import re
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException
from PIL import Image
if TYPE_CHECKING:
from app.models.model import UserAgentInfo
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
@@ -154,81 +158,79 @@ def check_image(content: bytes, size: int, width: int, height: int) -> str:
raise HTTPException(status_code=400, detail=f"Error processing image: {e}")
def simplify_user_agent(user_agent: str | None, max_length: int = 200) -> str | None:
"""
简化 User-Agent 字符串,只保留 osu! 和关键设备系统信息浏览器
def extract_user_agent(user_agent: str | None) -> "UserAgentInfo":
from app.models.model import UserAgentInfo
Args:
user_agent: 原始 User-Agent 字符串
max_length: 最大长度限制
raw_ua = user_agent or ""
ua = raw_ua.strip()
lower_ua = ua.lower()
Returns:
简化后的 User-Agent 字符串,或 None
"""
import re
info = UserAgentInfo(raw_ua=raw_ua)
if not user_agent:
return None
if not ua:
return info
# 如果长度在限制内,直接返回
if len(user_agent) <= max_length:
return user_agent
client_identifiers = ("osu!", "osu!lazer", "osu-framework")
if any(identifier in lower_ua for identifier in client_identifiers):
info.browser = "osu!"
info.is_client = True
return info
# 提取操作系统信息
os_info = ""
os_patterns = [
r"(Windows[^;)]*)",
r"(Mac OS[^;)]*)",
r"(Linux[^;)]*)",
r"(Android[^;)]*)",
r"(iOS[^;)]*)",
r"(iPhone[^;)]*)",
r"(iPad[^;)]*)",
]
browser_patterns: tuple[tuple[re.Pattern[str], str], ...] = (
(re.compile(r"OPR/(\d+(?:\.\d+)*)"), "Opera"),
(re.compile(r"Edg/(\d+(?:\.\d+)*)"), "Edge"),
(re.compile(r"Chrome/(\d+(?:\.\d+)*)"), "Chrome"),
(re.compile(r"Firefox/(\d+(?:\.\d+)*)"), "Firefox"),
(re.compile(r"Version/(\d+(?:\.\d+)*).*Safari"), "Safari"),
(re.compile(r"Safari/(\d+(?:\.\d+)*)"), "Safari"),
(re.compile(r"MSIE (\d+(?:\.\d+)*)"), "Internet Explorer"),
(re.compile(r"Trident/.*rv:(\d+(?:\.\d+)*)"), "Internet Explorer"),
)
for pattern in os_patterns:
match = re.search(pattern, user_agent, re.IGNORECASE)
for pattern, name in browser_patterns:
match = pattern.search(ua)
if match:
os_info = match.group(1).strip()
info.browser = name
info.version = match.group(1)
break
# 提取浏览器信息
browser_info = ""
browser_patterns = [
r"(osu![^)]*)", # osu! 客户端
r"(Chrome/[\d.]+)",
r"(Firefox/[\d.]+)",
r"(Safari/[\d.]+)",
r"(Edge/[\d.]+)",
r"(Opera/[\d.]+)",
]
os_patterns: tuple[tuple[re.Pattern[str], str], ...] = (
(re.compile(r"windows nt 10"), "Windows 10"),
(re.compile(r"windows nt 6\.3"), "Windows 8.1"),
(re.compile(r"windows nt 6\.2"), "Windows 8"),
(re.compile(r"windows nt 6\.1"), "Windows 7"),
(re.compile(r"windows nt 6\.0"), "Windows Vista"),
(re.compile(r"windows nt 5\.1"), "Windows XP"),
(re.compile(r"mac os x"), "macOS"),
(re.compile(r"iphone os"), "iOS"),
(re.compile(r"ipad;"), "iPadOS"),
(re.compile(r"android"), "Android"),
(re.compile(r"linux"), "Linux"),
)
for pattern in browser_patterns:
match = re.search(pattern, user_agent, re.IGNORECASE)
if match:
browser_info = match.group(1).strip()
# 如果找到了 osu! 客户端,优先使用
if "osu!" in browser_info.lower():
break
for pattern, name in os_patterns:
if pattern.search(lower_ua):
info.os = name
break
# 构建简化的 User-Agent
parts = []
if os_info:
parts.append(os_info)
if browser_info:
parts.append(browser_info)
info.is_mobile = any(keyword in lower_ua for keyword in ("mobile", "iphone", "android", "ipod"))
info.is_tablet = any(keyword in lower_ua for keyword in ("ipad", "tablet"))
# Only classify as PC if not mobile or tablet
if (
not info.is_mobile
and not info.is_tablet
and any(keyword in lower_ua for keyword in ("windows", "macintosh", "linux", "x11"))
):
info.is_pc = True
if parts:
simplified = "; ".join(parts)
else:
# 如果没有识别到关键信息,截断原始字符串
simplified = user_agent[: max_length - 3] + "..."
if info.is_tablet:
info.platform = "tablet"
elif info.is_mobile:
info.platform = "mobile"
elif info.is_pc:
info.platform = "pc"
# 确保不超过最大长度
if len(simplified) > max_length:
simplified = simplified[: max_length - 3] + "..."
return simplified
return info
# https://github.com/encode/starlette/blob/master/starlette/_utils.py

View File

@@ -25,10 +25,6 @@ from app.router import (
from app.router.redirect import redirect_router
from app.router.v1 import api_v1_public_router
from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
from app.scheduler.database_cleanup_scheduler import (
start_database_cleanup_scheduler,
stop_database_cleanup_scheduler,
)
from app.service.beatmap_download_service import download_service
from app.service.beatmapset_update_service import init_beatmapset_update_service
from app.service.calculate_all_user_rank import calculate_user_rank
@@ -68,7 +64,6 @@ async def lifespan(app: FastAPI):
await start_email_processor() # 启动邮件队列处理器
await download_service.start_health_check() # 启动下载服务健康检查
await start_cache_scheduler() # 启动缓存调度器
await start_database_cleanup_scheduler() # 启动数据库清理调度器
init_beatmapset_update_service(fetcher) # 初始化谱面集更新服务
redis_message_system.start() # 启动 Redis 消息系统
load_achievements()
@@ -83,7 +78,6 @@ async def lifespan(app: FastAPI):
stop_scheduler()
redis_message_system.stop() # 停止 Redis 消息系统
await stop_cache_scheduler() # 停止缓存调度器
await stop_database_cleanup_scheduler() # 停止数据库清理调度器
await download_service.stop_health_check() # 停止下载服务健康检查
await stop_email_processor() # 停止邮件队列处理器
await engine.dispose()

View File

@@ -0,0 +1,102 @@
"""session: support multi-session
Revision ID: 72a9b8f3f863
Revises: b1ac2154bd0d
Create Date: 2025-10-02 07:17:19.297498
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "72a9b8f3f863"
down_revision: str | Sequence[str] | None = "b1ac2154bd0d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"trusted_devices",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("ip_address", sa.VARCHAR(length=45), nullable=False),
sa.Column("user_agent", sa.Text(), nullable=False),
sa.Column("client_type", sa.VARCHAR(length=10), nullable=False),
sa.Column("web_uuid", sa.VARCHAR(length=36), nullable=True),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("last_used_at", sa.DateTime(), nullable=False),
sa.Column("expires_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["lazer_users.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.alter_column(
"login_sessions",
"is_new_location",
new_column_name="is_new_device",
existing_type=mysql.TINYINT(display_width=1),
)
op.create_index(op.f("ix_trusted_devices_user_id"), "trusted_devices", ["user_id"], unique=False)
op.add_column("login_sessions", sa.Column("web_uuid", sa.VARCHAR(length=36), nullable=True))
op.alter_column(
"login_sessions",
"ip_address",
existing_type=mysql.VARCHAR(length=255),
type_=sa.VARCHAR(length=45),
existing_nullable=False,
)
op.alter_column(
"login_sessions", "user_agent", existing_type=mysql.VARCHAR(length=250), type_=sa.Text(), existing_nullable=True
)
op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions")
op.create_foreign_key(None, "login_sessions", "lazer_users", ["user_id"], ["id"])
op.drop_column("login_sessions", "country_code")
op.drop_column("login_sessions", "session_token")
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("login_sessions", sa.Column("session_token", sa.VARCHAR(length=64), nullable=True))
op.add_column("login_sessions", sa.Column("country_code", sa.VARCHAR(length=255), nullable=True))
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False)
op.alter_column(
"login_sessions",
"user_agent",
existing_type=sa.Text(),
type_=mysql.VARCHAR(length=250),
existing_nullable=True,
)
op.alter_column(
"login_sessions",
"ip_address",
existing_type=sa.String(length=45),
type_=mysql.VARCHAR(length=255),
existing_nullable=False,
)
op.drop_column("login_sessions", "web_uuid")
op.alter_column(
"login_sessions",
"is_new_device",
new_column_name="is_new_location",
existing_type=mysql.TINYINT(display_width=1),
)
op.drop_constraint(op.f("fk_login_sessions_user_id_lazer_users"), "login_sessions", type_="foreignkey")
op.drop_index(op.f("ix_trusted_devices_user_id"), table_name="trusted_devices")
op.drop_table("trusted_devices")
# ### end Alembic commands ###

View File

@@ -0,0 +1,40 @@
"""auth: add refresh_token_expires_at
Revision ID: 7fe1319250c5
Revises: 72a9b8f3f863
Create Date: 2025-10-02 10:50:21.169065
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "7fe1319250c5"
down_revision: str | Sequence[str] | None = "72a9b8f3f863"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("oauth_tokens", sa.Column("refresh_token_expires_at", sa.DateTime(), nullable=True))
op.create_index(op.f("ix_oauth_tokens_expires_at"), "oauth_tokens", ["expires_at"], unique=False)
op.create_index(
op.f("ix_oauth_tokens_refresh_token_expires_at"), "oauth_tokens", ["refresh_token_expires_at"], unique=False
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_oauth_tokens_refresh_token_expires_at"), table_name="oauth_tokens")
op.drop_index(op.f("ix_oauth_tokens_expires_at"), table_name="oauth_tokens")
op.drop_column("oauth_tokens", "refresh_token_expires_at")
# ### end Alembic commands ###

View File

@@ -0,0 +1,35 @@
"""session: add device_id to LoginSession
Revision ID: 9556cd2ec11f
Revises: 7fe1319250c5
Create Date: 2025-10-02 11:03:09.803140
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "9556cd2ec11f"
down_revision: str | Sequence[str] | None = "7fe1319250c5"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("login_sessions", sa.Column("device_id", sa.BigInteger(), nullable=True))
op.create_index(op.f("ix_login_sessions_device_id"), "login_sessions", ["device_id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("login_sessions", "device_id")
# ### end Alembic commands ###