feat(auth): support trusted device (#52)
New API to maintain sessions and devices:
- GET /api/private/admin/sessions
- DELETE /api/private/admin/sessions/{session_id}
- GET /api/private/admin/trusted-devices
- DELETE /api/private/admin/trusted-devices/{device_id}
Auth:
web clients request `/oauth/token` and `/api/v2/session/verify` with `X-UUID` header to save the client as trusted device.
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -217,10 +217,12 @@ async def store_token(
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
refresh_token_expires_in: int,
|
||||
allow_multiple_devices: bool = True,
|
||||
) -> OAuthToken:
|
||||
"""存储令牌到数据库(支持多设备)"""
|
||||
expires_at = utcnow() + timedelta(seconds=expires_in)
|
||||
refresh_token_expires_at = utcnow() + timedelta(seconds=refresh_token_expires_in)
|
||||
|
||||
if not allow_multiple_devices:
|
||||
# 旧的行为:删除用户的旧令牌(单设备模式)
|
||||
@@ -266,6 +268,7 @@ async def store_token(
|
||||
scope=",".join(scopes),
|
||||
refresh_token=refresh_token,
|
||||
expires_at=expires_at,
|
||||
refresh_token_expires_at=refresh_token_expires_at,
|
||||
)
|
||||
db.add(token_record)
|
||||
await db.commit()
|
||||
@@ -290,7 +293,7 @@ async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OA
|
||||
"""根据刷新令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
OAuthToken.expires_at > utcnow(),
|
||||
OAuthToken.refresh_token_expires_at > utcnow(),
|
||||
)
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
@@ -170,6 +170,11 @@ STORAGE_SETTINGS='{
|
||||
Field(default=1440, description="访问令牌过期时间(分钟)"),
|
||||
"JWT 设置",
|
||||
]
|
||||
refresh_token_expire_minutes: Annotated[
|
||||
int,
|
||||
Field(default=21600, description="刷新令牌过期时间(分钟)"),
|
||||
"JWT 设置",
|
||||
] # 15 days
|
||||
jwt_audience: Annotated[
|
||||
str,
|
||||
Field(default="5", description="JWT 受众"),
|
||||
@@ -349,11 +354,6 @@ STORAGE_SETTINGS='{
|
||||
Field(default=30, description="设备信任持续天数"),
|
||||
"验证服务设置",
|
||||
]
|
||||
location_trust_duration_days: Annotated[
|
||||
int,
|
||||
Field(default=90, description="位置信任持续天数"),
|
||||
"验证服务设置",
|
||||
]
|
||||
smtp_server: Annotated[
|
||||
str,
|
||||
Field(default="localhost", description="SMTP 服务器地址"),
|
||||
|
||||
@@ -3,3 +3,5 @@ from __future__ import annotations
|
||||
BANCHOBOT_ID = 2
|
||||
|
||||
BACKUP_CODE_LENGTH = 10
|
||||
|
||||
SUPPORT_TOTP_VERIFICATION_VER = 20250913
|
||||
|
||||
@@ -68,7 +68,7 @@ from .user_account_history import (
|
||||
UserAccountHistoryType,
|
||||
)
|
||||
from .user_login_log import UserLoginLog
|
||||
from .verification import EmailVerification, LoginSession
|
||||
from .verification import EmailVerification, LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
|
||||
|
||||
__all__ = [
|
||||
"APIUploadedRoom",
|
||||
@@ -96,6 +96,7 @@ __all__ = [
|
||||
"ItemAttemptsCount",
|
||||
"ItemAttemptsResp",
|
||||
"LoginSession",
|
||||
"LoginSessionResp",
|
||||
"MeResp",
|
||||
"MonthlyPlaycounts",
|
||||
"MultiplayerEvent",
|
||||
@@ -131,6 +132,8 @@ __all__ = [
|
||||
"TeamMember",
|
||||
"TeamRequest",
|
||||
"TotpKeys",
|
||||
"TrustedDevice",
|
||||
"TrustedDeviceResp",
|
||||
"User",
|
||||
"UserAccountHistory",
|
||||
"UserAccountHistoryResp",
|
||||
|
||||
@@ -32,7 +32,8 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
refresh_token: str = Field(max_length=500, unique=True)
|
||||
token_type: str = Field(default="Bearer", max_length=20)
|
||||
scope: str = Field(default="*", max_length=100)
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime))
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
refresh_token_expires_at: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship()
|
||||
|
||||
@@ -243,7 +243,6 @@ class UserResp(UserBase):
|
||||
user_achievements: list[UserAchievementResp] = Field(default_factory=list)
|
||||
cover_url: str = "" # deprecated
|
||||
team: Team | None = None
|
||||
session_verified: bool = True
|
||||
daily_challenge_user_stats: DailyChallengeStatsResp | None = None
|
||||
default_group: str = ""
|
||||
is_deleted: bool = False # TODO
|
||||
@@ -425,27 +424,18 @@ class UserResp(UserBase):
|
||||
)
|
||||
).one()
|
||||
|
||||
if "session_verified" in include:
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
u.session_verified = (
|
||||
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
|
||||
if token_id
|
||||
else True
|
||||
)
|
||||
|
||||
return u
|
||||
|
||||
|
||||
class MeResp(UserResp):
|
||||
session_verification_method: Literal["totp", "mail"] | None = None
|
||||
session_verified: bool = True
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
obj: User,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
ruleset: GameMode | None = None,
|
||||
*,
|
||||
token_id: int | None = None,
|
||||
@@ -453,7 +443,12 @@ class MeResp(UserResp):
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.verification_service import LoginSessionService
|
||||
|
||||
u = await super().from_db(obj, session, ["session_verified", *include], ruleset, token_id=token_id)
|
||||
u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
|
||||
u.session_verified = (
|
||||
not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
|
||||
if token_id
|
||||
else True
|
||||
)
|
||||
u = cls.model_validate(u.model_dump())
|
||||
if (settings.enable_totp_verification or settings.enable_email_verification) and token_id:
|
||||
redis = get_redis()
|
||||
|
||||
@@ -3,17 +3,26 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
from app.utils import utcnow
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.models.model import UserAgentInfo, UTCBaseModel
|
||||
from app.utils import extract_user_agent, utcnow
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey
|
||||
from sqlmodel import Field, Integer, Relationship, SQLModel
|
||||
from sqlmodel import VARCHAR, DateTime, Field, Integer, Relationship, SQLModel, Text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .auth import OAuthToken
|
||||
|
||||
|
||||
class Location(BaseModel):
|
||||
country: str = ""
|
||||
city: str = ""
|
||||
country_code: str = ""
|
||||
|
||||
|
||||
class EmailVerification(SQLModel, table=True):
|
||||
"""邮件验证记录"""
|
||||
|
||||
@@ -31,25 +40,90 @@ class EmailVerification(SQLModel, table=True):
|
||||
user_agent: str | None = Field(default=None) # 用户代理
|
||||
|
||||
|
||||
class LoginSession(SQLModel, table=True):
|
||||
class LoginSessionBase(SQLModel):
|
||||
"""登录会话记录"""
|
||||
|
||||
__tablename__: str = "login_sessions"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
token_id: int | None = Field(
|
||||
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True)
|
||||
)
|
||||
ip_address: str = Field() # 登录IP
|
||||
user_agent: str | None = Field(default=None, max_length=250)
|
||||
country_code: str | None = Field(default=None)
|
||||
ip_address: str = Field(sa_column=Column(VARCHAR(45), nullable=False), default="127.0.0.1", exclude=True)
|
||||
user_agent: str | None = Field(default=None, sa_column=Column(Text))
|
||||
is_verified: bool = Field(default=False) # 是否已验证
|
||||
created_at: datetime = Field(default_factory=lambda: utcnow())
|
||||
verified_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime = Field() # 会话过期时间
|
||||
is_new_location: bool = Field(default=False) # 是否新位置登录
|
||||
session_token: str | None = Field(default=None, max_length=64, index=True) # 会话令牌
|
||||
verification_method: str | None = Field(default=None, max_length=20) # 验证方法 (totp/mail)
|
||||
device_id: int | None = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("trusted_devices.id", ondelete="SET NULL"), nullable=True, index=True),
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class LoginSession(LoginSessionBase, table=True):
|
||||
__tablename__: str = "login_sessions"
|
||||
token_id: int | None = Field(
|
||||
sa_column=Column(Integer, ForeignKey("oauth_tokens.id", ondelete="SET NULL"), nullable=True, index=True),
|
||||
exclude=True,
|
||||
)
|
||||
is_new_device: bool = Field(default=False, exclude=True) # 是否新位置登录
|
||||
web_uuid: str | None = Field(sa_column=Column(VARCHAR(36), nullable=True), default=None, exclude=True)
|
||||
verification_method: str | None = Field(default=None, max_length=20, exclude=True) # 验证方法 (totp/mail)
|
||||
|
||||
device: Optional["TrustedDevice"] = Relationship(back_populates="sessions")
|
||||
token: Optional["OAuthToken"] = Relationship(back_populates="login_session")
|
||||
|
||||
|
||||
class LoginSessionResp(UTCBaseModel, LoginSessionBase):
|
||||
user_agent_info: UserAgentInfo | None = None
|
||||
location: Location | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, obj: LoginSession, get_geoip_helper: GeoIPHelper) -> "LoginSessionResp":
|
||||
session = cls.model_validate(obj.model_dump())
|
||||
session.user_agent_info = extract_user_agent(session.user_agent)
|
||||
if obj.ip_address:
|
||||
loc = get_geoip_helper.lookup(obj.ip_address)
|
||||
session.location = Location(
|
||||
country=loc.get("country_name", ""),
|
||||
city=loc.get("city_name", ""),
|
||||
country_code=loc.get("country_code", ""),
|
||||
)
|
||||
else:
|
||||
session.location = None
|
||||
return session
|
||||
|
||||
|
||||
class TrustedDeviceBase(SQLModel):
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
ip_address: str = Field(sa_column=Column(VARCHAR(45), nullable=False), default="127.0.0.1", exclude=True)
|
||||
user_agent: str = Field(sa_column=Column(Text, nullable=False))
|
||||
client_type: Literal["web", "client"] = Field(sa_column=Column(VARCHAR(10), nullable=False), default="web")
|
||||
created_at: datetime = Field(default_factory=utcnow)
|
||||
last_used_at: datetime = Field(default_factory=utcnow)
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime))
|
||||
|
||||
|
||||
class TrustedDevice(TrustedDeviceBase, table=True):
|
||||
__tablename__: str = "trusted_devices"
|
||||
web_uuid: str | None = Field(sa_column=Column(VARCHAR(36), nullable=True), default=None)
|
||||
|
||||
sessions: list["LoginSession"] = Relationship(back_populates="device", passive_deletes=True)
|
||||
|
||||
|
||||
class TrustedDeviceResp(UTCBaseModel, TrustedDeviceBase):
|
||||
user_agent_info: UserAgentInfo | None = None
|
||||
location: Location | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, device: TrustedDevice, get_geoip_helper: GeoIPHelper) -> "TrustedDeviceResp":
|
||||
device_ = cls.model_validate(device.model_dump())
|
||||
device_.user_agent_info = extract_user_agent(device_.user_agent)
|
||||
if device_.ip_address:
|
||||
loc = get_geoip_helper.lookup(device_.ip_address)
|
||||
device_.location = Location(
|
||||
country=loc.get("country_name", ""),
|
||||
city=loc.get("city_name", ""),
|
||||
country_code=loc.get("country_code", ""),
|
||||
)
|
||||
else:
|
||||
device_.location = None
|
||||
return device_
|
||||
|
||||
15
app/dependencies/user_agent.py
Normal file
15
app/dependencies/user_agent.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.models.model import UserAgentInfo as UserAgentInfoModel
|
||||
from app.utils import extract_user_agent
|
||||
|
||||
from fastapi import Depends, Header
|
||||
|
||||
|
||||
def get_user_agent_info(user_agent: str | None = Header(None, include_in_schema=False)) -> UserAgentInfoModel:
|
||||
return extract_user_agent(user_agent)
|
||||
|
||||
|
||||
UserAgentInfo = Annotated[UserAgentInfoModel, Depends(get_user_agent_info)]
|
||||
@@ -1,301 +0,0 @@
|
||||
"""
|
||||
会话验证中间件和状态管理
|
||||
|
||||
基于osu-web的会话验证系统实现
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import ClassVar, Literal, cast
|
||||
|
||||
from app.database.lazer_user import User
|
||||
from app.database.verification import LoginSession
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.log import logger
|
||||
from app.service.verification_service import LoginSessionService
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class SessionVerificationState:
|
||||
"""会话验证状态管理类
|
||||
|
||||
参考osu-web的State类实现
|
||||
"""
|
||||
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis):
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.redis = redis
|
||||
|
||||
@classmethod
|
||||
async def get_current(
|
||||
cls,
|
||||
request: Request,
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user: User,
|
||||
) -> SessionVerificationState | None:
|
||||
"""获取当前会话验证状态"""
|
||||
try:
|
||||
# 从请求头或token中获取会话信息
|
||||
session_token = cls._extract_session_token(request)
|
||||
if not session_token:
|
||||
return None
|
||||
|
||||
# 查找会话
|
||||
session = await LoginSessionService.find_for_verification(db, session_token)
|
||||
if not session or session.user_id != user.id:
|
||||
return None
|
||||
|
||||
return cls(session, user, redis)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error getting current state: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_token(request: Request) -> str | None:
|
||||
"""从请求中提取会话token"""
|
||||
# 尝试从Authorization header提取
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:] # 移除"Bearer "前缀
|
||||
|
||||
# 可以扩展其他提取方式
|
||||
return None
|
||||
|
||||
def get_method(self) -> str:
|
||||
"""获取验证方法
|
||||
|
||||
参考osu-web的逻辑,智能选择验证方法
|
||||
"""
|
||||
current_method = self.session.verification_method
|
||||
|
||||
if current_method is None:
|
||||
# 智能选择验证方法
|
||||
# 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证
|
||||
# 这里简化为检查用户是否有TOTP
|
||||
totp_key = getattr(self.user, "totp_key", None)
|
||||
current_method = "totp" if totp_key else "mail"
|
||||
|
||||
# 设置验证方法
|
||||
bg_tasks.add_task(self._set_verification_method, current_method)
|
||||
|
||||
return current_method
|
||||
|
||||
async def _set_verification_method(self, method: str) -> None:
|
||||
"""内部方法:设置验证方法"""
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None and method in ["totp", "mail"]:
|
||||
# 类型检查确保method是正确的字面量类型
|
||||
verification_method = method if method in ["totp", "mail"] else "totp"
|
||||
await LoginSessionService.set_login_method(
|
||||
self.user.id,
|
||||
token_id,
|
||||
cast(Literal["totp", "mail"], verification_method),
|
||||
self.redis,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error setting verification method: {e}")
|
||||
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
return self.session.is_verified
|
||||
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
try:
|
||||
async with with_db() as db:
|
||||
# 创建专用数据库会话
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error marking session verified: {e}")
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥"""
|
||||
return str(self.session.id) if self.session.id else ""
|
||||
|
||||
def get_key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return LoginSessionService.get_key_for_event(self.get_key())
|
||||
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
|
||||
async def issue_mail_if_needed(self) -> None:
|
||||
"""如果需要,发送验证邮件"""
|
||||
try:
|
||||
if self.get_method() == "mail":
|
||||
from app.service.verification_service import EmailVerificationService
|
||||
|
||||
# 创建专用数据库会话发送邮件
|
||||
async with with_db() as db:
|
||||
await EmailVerificationService.send_verification_email(
|
||||
db, self.redis, self.user.id, self.user.username, self.user.email, None, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error issuing mail: {e}")
|
||||
|
||||
|
||||
class SessionVerificationController:
|
||||
"""会话验证控制器
|
||||
|
||||
参考osu-web的Controller类实现
|
||||
"""
|
||||
|
||||
# 需要跳过验证的路由(参考osu-web的SKIP_VERIFICATION_ROUTES)
|
||||
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
|
||||
"/api/v2/session/verify",
|
||||
"/api/v2/session/verify/reissue",
|
||||
"/api/v2/me",
|
||||
"/api/v2/logout",
|
||||
"/oauth/token",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def should_skip_verification(request: Request) -> bool:
|
||||
"""检查是否应该跳过验证"""
|
||||
path = request.url.path
|
||||
return path in SessionVerificationController.SKIP_VERIFICATION_ROUTES
|
||||
|
||||
@staticmethod
|
||||
async def initiate_verification(
|
||||
state: SessionVerificationState,
|
||||
request: Request,
|
||||
) -> Response:
|
||||
"""启动会话验证流程
|
||||
|
||||
参考osu-web的initiate方法
|
||||
"""
|
||||
try:
|
||||
method = state.get_method()
|
||||
|
||||
# 如果是邮件验证,发送验证邮件
|
||||
if method == "mail":
|
||||
await state.issue_mail_if_needed()
|
||||
|
||||
# API请求返回JSON响应
|
||||
if request.url.path.startswith("/api/"):
|
||||
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method})
|
||||
|
||||
# 其他情况可以扩展支持HTML响应
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"authentication": "verify", "method": method, "message": "Session verification required"},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification] Error initiating verification: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed"
|
||||
)
|
||||
|
||||
|
||||
class SessionVerificationMiddleware:
|
||||
"""会话验证中间件
|
||||
|
||||
参考osu-web的VerifyUser中间件实现
|
||||
"""
|
||||
|
||||
def __init__(self, app: Callable[[Request], Awaitable[Response]]):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
"""中间件主要逻辑"""
|
||||
try:
|
||||
# 检查是否需要跳过验证
|
||||
if SessionVerificationController.should_skip_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取依赖项
|
||||
user = await self._get_user(request)
|
||||
if not user:
|
||||
# 未认证用户跳过验证
|
||||
return await call_next(request)
|
||||
|
||||
# 获取数据库和Redis连接
|
||||
async with with_db() as db:
|
||||
redis = await self._get_redis()
|
||||
|
||||
# 获取会话验证状态
|
||||
state = await SessionVerificationState.get_current(request, db, redis, user)
|
||||
if not state:
|
||||
# 无法获取会话状态,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否已验证
|
||||
if state.is_verified():
|
||||
# 已验证,继续请求
|
||||
return await call_next(request)
|
||||
|
||||
# 检查是否需要验证
|
||||
if not self._requires_verification(request):
|
||||
return await call_next(request)
|
||||
|
||||
# 启动验证流程
|
||||
return await SessionVerificationController.initiate_verification(state, request)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Session Verification Middleware] Unexpected error: {e}")
|
||||
# 出错时允许请求继续,避免阻塞正常流程
|
||||
return await call_next(request)
|
||||
|
||||
async def _get_user(self, request: Request) -> User | None:
|
||||
"""获取当前用户"""
|
||||
try:
|
||||
# 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入
|
||||
# 简化实现,实际应该从token中解析用户
|
||||
return None # 暂时返回None,需要实际实现
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _get_redis(self) -> Redis:
|
||||
"""获取Redis连接"""
|
||||
return get_redis()
|
||||
|
||||
def _requires_verification(self, request: Request) -> bool:
|
||||
"""检查是否需要验证
|
||||
|
||||
参考osu-web的requiresVerification方法
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# GET/HEAD/OPTIONS请求一般不需要验证
|
||||
safe_methods = {"GET", "HEAD", "OPTIONS"}
|
||||
if method in safe_methods:
|
||||
return False
|
||||
|
||||
# POST/PUT/DELETE等修改操作需要验证
|
||||
return True
|
||||
|
||||
|
||||
# FastAPI中间件包装器
|
||||
class FastAPISessionVerificationMiddleware:
|
||||
"""FastAPI会话验证中间件包装器"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] != "http":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
async def call_next(req: Request) -> Response:
|
||||
# 这里需要调用FastAPI应用
|
||||
return Response("OK") # 占位符实现
|
||||
|
||||
middleware = SessionVerificationMiddleware(call_next)
|
||||
response = await middleware(request, call_next)
|
||||
|
||||
await response(scope, receive, send)
|
||||
@@ -10,11 +10,13 @@ from collections.abc import Callable
|
||||
from typing import ClassVar
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.const import SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database.lazer_user import User
|
||||
from app.database.verification import LoginSession
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.log import logger
|
||||
from app.service.verification_service import LoginSessionService
|
||||
from app.utils import extract_user_agent
|
||||
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -34,7 +36,9 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
|
||||
"/api/v2/session/verify",
|
||||
"/api/v2/session/verify/reissue",
|
||||
"/api/v2/session/verify/mail-fallback",
|
||||
"/api/v2/me",
|
||||
"/api/v2/me/",
|
||||
"/api/v2/logout",
|
||||
"/oauth/token",
|
||||
"/health",
|
||||
@@ -44,10 +48,8 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
"/redoc",
|
||||
}
|
||||
|
||||
# 需要强制验证的路由模式(敏感操作)
|
||||
# 总是需要验证的路由前缀
|
||||
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
|
||||
"/api/v2/account/",
|
||||
"/api/v2/settings/",
|
||||
"/api/private/admin/",
|
||||
}
|
||||
|
||||
@@ -110,9 +112,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
# 特权用户或非活跃用户需要验证
|
||||
# if hasattr(user, 'is_privileged') and user.is_privileged():
|
||||
# return True
|
||||
if not user.is_active:
|
||||
return True
|
||||
|
||||
@@ -154,6 +153,14 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
try:
|
||||
# 提取会话token(这里简化为使用相同的auth token)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
api_version = 0
|
||||
raw_api_version = request.headers.get("x-api-version")
|
||||
if raw_api_version is not None:
|
||||
try:
|
||||
api_version = int(raw_api_version)
|
||||
except ValueError:
|
||||
api_version = 0
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
@@ -168,7 +175,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
if not session or session.user_id != user.id:
|
||||
return None
|
||||
|
||||
return SessionState(session, user, redis, db)
|
||||
return SessionState(session, user, redis, db, api_version)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Verify Session Middleware] Error getting session state: {e}")
|
||||
@@ -178,8 +185,6 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
"""启动验证流程"""
|
||||
try:
|
||||
method = await state.get_method()
|
||||
|
||||
# 如果是邮件验证,可以在这里触发发送邮件
|
||||
if method == "mail":
|
||||
await state.issue_mail_if_needed()
|
||||
|
||||
@@ -202,11 +207,12 @@ class SessionState:
|
||||
简化版本的会话状态管理
|
||||
"""
|
||||
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession):
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession, api_version: int = 0) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.redis = redis
|
||||
self.db = db
|
||||
self.api_version = api_version
|
||||
self._verification_method: str | None = None
|
||||
|
||||
def is_verified(self) -> bool:
|
||||
@@ -223,14 +229,15 @@ class SessionState:
|
||||
self.user.id, token_id, self.redis
|
||||
)
|
||||
|
||||
# 如果没有设置,智能选择
|
||||
if self._verification_method is None:
|
||||
# 检查用户是否有TOTP密钥
|
||||
await self.user.awaitable_attrs.totp_key # 预加载
|
||||
totp_key = getattr(self.user, "totp_key", None)
|
||||
if self.api_version < SUPPORT_TOTP_VERIFICATION_VER:
|
||||
self._verification_method = "mail"
|
||||
return self._verification_method
|
||||
|
||||
await self.user.awaitable_attrs.totp_key
|
||||
totp_key = self.user.totp_key
|
||||
self._verification_method = "totp" if totp_key else "mail"
|
||||
|
||||
# 保存选择的方法
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.set_login_method(
|
||||
@@ -244,8 +251,15 @@ class SessionState:
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.mark_session_verified(self.db, self.redis, self.user.id, token_id)
|
||||
self.session.is_verified = True # 更新本地状态
|
||||
await LoginSessionService.mark_session_verified(
|
||||
self.db,
|
||||
self.redis,
|
||||
self.user.id,
|
||||
token_id,
|
||||
self.session.ip_address,
|
||||
extract_user_agent(self.session.user_agent),
|
||||
self.session.web_uuid,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Session State] Error marking verified: {e}")
|
||||
|
||||
@@ -266,10 +280,12 @@ class SessionState:
|
||||
"""获取会话密钥"""
|
||||
return str(self.session.id) if self.session.id else ""
|
||||
|
||||
def get_key_for_event(self) -> str:
|
||||
@property
|
||||
def key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return LoginSessionService.get_key_for_event(self.get_key())
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.models.score import GameMode
|
||||
@@ -53,3 +54,33 @@ class CurrentUserAttributes(BaseModel):
|
||||
can_new_comment: bool | None = None
|
||||
can_new_comment_reason: str | None = None
|
||||
pin: PinAttributes | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserAgentInfo:
|
||||
raw_ua: str = ""
|
||||
browser: str | None = None
|
||||
version: str | None = None
|
||||
os: str | None = None
|
||||
platform: str | None = None
|
||||
is_mobile: bool = False
|
||||
is_tablet: bool = False
|
||||
is_pc: bool = False
|
||||
is_client: bool = False
|
||||
|
||||
@property
|
||||
def displayed_name(self) -> str:
|
||||
parts = []
|
||||
if self.browser:
|
||||
parts.append(self.browser)
|
||||
if self.version:
|
||||
parts.append(self.version)
|
||||
if self.os:
|
||||
if parts:
|
||||
parts.append(f"on {self.os}")
|
||||
else:
|
||||
parts.append(self.os)
|
||||
return " ".join(parts) if parts else "Unknown"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.displayed_name
|
||||
|
||||
@@ -21,6 +21,7 @@ from app.database.auth import TotpKeys
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
@@ -39,7 +40,7 @@ from app.service.verification_service import (
|
||||
)
|
||||
from app.utils import utcnow
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
from fastapi import APIRouter, Depends, Form, Header, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
@@ -199,6 +200,7 @@ async def register_user(
|
||||
async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
user_agent: UserAgentInfo,
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
@@ -211,12 +213,10 @@ async def oauth_token(
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
# 打印请求头
|
||||
# logger.info(f"Request headers: {request.headers}")
|
||||
|
||||
client = (
|
||||
await db.exec(
|
||||
select(OAuthClient).where(
|
||||
@@ -306,19 +306,19 @@ async def oauth_token(
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
settings.refresh_token_expire_minutes * 60,
|
||||
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
|
||||
)
|
||||
token_id = token.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
# 检查是否为新位置登录
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
trusted_device = await LoginSessionService.check_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
|
||||
|
||||
session_verification_method = None
|
||||
if settings.enable_totp_verification and totp_key is not None:
|
||||
@@ -331,18 +331,12 @@ async def oauth_token(
|
||||
login_method="password_pending_verification",
|
||||
notes="需要 TOTP 验证",
|
||||
)
|
||||
elif is_new_location and settings.enable_email_verification:
|
||||
# 如果是新位置登录,需要邮件验证
|
||||
elif not trusted_device and settings.enable_email_verification:
|
||||
# 如果是新设备登录,需要邮件验证
|
||||
# 刷新用户对象以确保属性已加载
|
||||
await db.refresh(user)
|
||||
session_verification_method = "mail"
|
||||
|
||||
# 使用智能验证发送邮件
|
||||
(
|
||||
verification_sent,
|
||||
verification_message,
|
||||
client_info,
|
||||
) = await EmailVerificationService.send_smart_verification_email(
|
||||
await EmailVerificationService.send_verification_email(
|
||||
db,
|
||||
redis,
|
||||
user_id,
|
||||
@@ -350,36 +344,30 @@ async def oauth_token(
|
||||
user.email,
|
||||
ip_address,
|
||||
user_agent,
|
||||
client_id,
|
||||
country_code,
|
||||
is_new_location,
|
||||
)
|
||||
|
||||
# 记录需要二次验证的登录尝试
|
||||
client_display_name = client_info.client_type if client_info else "unknown"
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
login_success=True,
|
||||
login_method="password_pending_verification",
|
||||
notes=f"智能验证: {verification_message} - 客户端: {client_display_name}, "
|
||||
f"IP: {ip_address}, 国家: {country_code}",
|
||||
notes=(
|
||||
f"邮箱验证: User-Agent: {user_agent.raw_ua}, 客户端: {user_agent.displayed_name} "
|
||||
f"IP: {ip_address}, 国家: {country_code}"
|
||||
),
|
||||
)
|
||||
elif not trusted_device:
|
||||
# 新设备登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||
await LoginSessionService.mark_session_verified(
|
||||
db, redis, user_id, token_id, ip_address, user_agent, web_uuid
|
||||
)
|
||||
|
||||
if not verification_sent:
|
||||
# 邮件发送失败,记录错误
|
||||
logger.error(f"[Auth] Smart verification failed for user {user_id}: {verification_message}")
|
||||
else:
|
||||
logger.info(f"[Auth] Smart verification result for user {user_id}: {verification_message}")
|
||||
elif is_new_location:
|
||||
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
|
||||
logger.debug(
|
||||
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
|
||||
)
|
||||
else:
|
||||
# 不是新位置登录,正常登录
|
||||
# 不是新设备登录,正常登录
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
@@ -391,12 +379,12 @@ async def oauth_token(
|
||||
|
||||
if session_verification_method:
|
||||
await LoginSessionService.create_session(
|
||||
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, False
|
||||
db, user_id, token_id, ip_address, user_agent.raw_ua, trusted_device, web_uuid, False
|
||||
)
|
||||
await LoginSessionService.set_login_method(user_id, token_id, session_verification_method, redis)
|
||||
else:
|
||||
await LoginSessionService.create_session(
|
||||
db, redis, user_id, token_id, ip_address, user_agent, country_code, is_new_location, True
|
||||
db, user_id, token_id, ip_address, user_agent.raw_ua, trusted_device, web_uuid, True
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
@@ -449,6 +437,7 @@ async def oauth_token(
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
settings.refresh_token_expire_minutes * 60,
|
||||
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
|
||||
)
|
||||
return TokenResponse(
|
||||
@@ -514,6 +503,7 @@ async def oauth_token(
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
settings.refresh_token_expire_minutes * 60,
|
||||
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
|
||||
)
|
||||
|
||||
@@ -561,6 +551,7 @@ async def oauth_token(
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
settings.refresh_token_expire_minutes * 60,
|
||||
allow_multiple_devices=settings.enable_multi_device_login, # 使用配置决定是否启用多设备支持
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from . import audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401
|
||||
from . import admin, audio_proxy, avatar, beatmapset, cover, oauth, relationship, score, team, username # noqa: F401
|
||||
from .router import router as private_router
|
||||
|
||||
if settings.enable_totp_verification:
|
||||
|
||||
157
app/router/private/admin.py
Normal file
157
app/router/private/admin.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.auth import OAuthToken
|
||||
from app.database.verification import LoginSession, LoginSessionResp, TrustedDevice, TrustedDeviceResp
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.geoip import get_geoip_helper
|
||||
from app.dependencies.user import UserAndToken, get_client_user_and_token
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
|
||||
|
||||
class SessionsResp(BaseModel):
|
||||
total: int
|
||||
current: int = 0
|
||||
sessions: list[LoginSessionResp]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/admin/sessions",
|
||||
name="获取当前用户的登录会话列表",
|
||||
tags=["用户会话", "g0v0 API", "管理"],
|
||||
response_model=SessionsResp,
|
||||
)
|
||||
async def get_sessions(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
sessions = (
|
||||
await session.exec(
|
||||
select(
|
||||
LoginSession,
|
||||
)
|
||||
.where(LoginSession.user_id == current_user.id, col(LoginSession.is_verified).is_(True))
|
||||
.order_by(col(LoginSession.created_at).desc())
|
||||
)
|
||||
).all()
|
||||
return SessionsResp(
|
||||
total=len(sessions),
|
||||
current=token.id,
|
||||
sessions=[LoginSessionResp.from_db(s, geoip) for s in sessions],
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/admin/sessions/{session_id}",
|
||||
name="注销指定的登录会话",
|
||||
tags=["用户会话", "g0v0 API", "管理"],
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_session(
|
||||
session: Database,
|
||||
session_id: int,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
if session_id == token.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete the current session")
|
||||
|
||||
db_session = await session.get(LoginSession, session_id)
|
||||
if not db_session or db_session.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
await session.delete(db_session)
|
||||
|
||||
token = await session.get(OAuthToken, db_session.token_id or 0)
|
||||
if token:
|
||||
await session.delete(token)
|
||||
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
|
||||
class TrustedDevicesResp(BaseModel):
|
||||
total: int
|
||||
current: int = 0
|
||||
devices: list[TrustedDeviceResp]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/admin/trusted-devices",
|
||||
name="获取当前用户的受信任设备列表",
|
||||
tags=["用户会话", "g0v0 API", "管理"],
|
||||
response_model=TrustedDevicesResp,
|
||||
)
|
||||
async def get_trusted_devices(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
devices = (
|
||||
await session.exec(
|
||||
select(TrustedDevice)
|
||||
.where(TrustedDevice.user_id == current_user.id)
|
||||
.order_by(col(TrustedDevice.last_used_at).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
current_device_id = (
|
||||
await session.exec(
|
||||
select(TrustedDevice.id)
|
||||
.join(LoginSession, col(LoginSession.device_id) == TrustedDevice.id)
|
||||
.where(
|
||||
LoginSession.token_id == token.id,
|
||||
TrustedDevice.user_id == current_user.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).first()
|
||||
|
||||
return TrustedDevicesResp(
|
||||
total=len(devices),
|
||||
current=current_device_id or 0,
|
||||
devices=[TrustedDeviceResp.from_db(device, geoip) for device in devices],
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/admin/trusted-devices/{device_id}",
|
||||
name="移除受信任设备",
|
||||
tags=["用户会话", "g0v0 API", "管理"],
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_trusted_device(
|
||||
session: Database,
|
||||
device_id: int,
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
):
|
||||
current_user, token = user_and_token
|
||||
device = await session.get(TrustedDevice, device_id)
|
||||
current_device_id = (
|
||||
await session.exec(
|
||||
select(TrustedDevice.id)
|
||||
.join(LoginSession, col(LoginSession.device_id) == TrustedDevice.id)
|
||||
.where(
|
||||
LoginSession.token_id == token.id,
|
||||
TrustedDevice.user_id == current_user.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).first()
|
||||
if device_id == current_device_id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete the current trusted device")
|
||||
|
||||
if not device or device.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Trusted device not found")
|
||||
|
||||
await session.delete(device)
|
||||
await session.commit()
|
||||
return
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import MeResp, User
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.dependencies.user import UserAndToken, get_current_user_and_token
|
||||
@@ -33,7 +32,7 @@ async def get_user_info_with_ruleset(
|
||||
ruleset: GameMode = Path(description="指定 ruleset"),
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, ruleset, token_id=user_and_token[1].id)
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
|
||||
|
||||
@@ -48,7 +47,7 @@ async def get_user_info_default(
|
||||
session: Database,
|
||||
user_and_token: UserAndToken = Security(get_current_user_and_token, scopes=["identify"]),
|
||||
):
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, ALL_INCLUDED, None, token_id=user_and_token[1].id)
|
||||
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
|
||||
return user_resp
|
||||
|
||||
|
||||
|
||||
@@ -8,12 +8,13 @@ from typing import Annotated, Literal
|
||||
|
||||
from app.auth import check_totp_backup_code, verify_totp_key_with_replay_protection
|
||||
from app.config import settings
|
||||
from app.const import BACKUP_CODE_LENGTH
|
||||
from app.const import BACKUP_CODE_LENGTH, SUPPORT_TOTP_VERIFICATION_VER
|
||||
from app.database.auth import TotpKeys
|
||||
from app.dependencies.api_version import APIVersion
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
from app.dependencies.user import UserAndToken, get_client_user_and_token
|
||||
from app.dependencies.user_agent import UserAgentInfo
|
||||
from app.log import logger
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.verification_service import (
|
||||
@@ -23,7 +24,7 @@ from app.service.verification_service import (
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Request, Security, status
|
||||
from fastapi import Depends, Form, Header, HTTPException, Request, Security, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
@@ -62,9 +63,11 @@ async def verify_session(
|
||||
request: Request,
|
||||
db: Database,
|
||||
api_version: APIVersion,
|
||||
user_agent: UserAgentInfo,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8 位邮件验证码或者 6 位 TOTP 代码或 10 位备份码 (g0v0 扩展支持)"),
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
web_uuid: str | None = Header(None, include_in_schema=False, alias="X-UUID"),
|
||||
) -> Response:
|
||||
current_user = user_and_token[0]
|
||||
token_id = user_and_token[1].id
|
||||
@@ -74,11 +77,12 @@ async def verify_session(
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
verify_method: str | None = (
|
||||
"mail" if api_version < 20250913 else await LoginSessionService.get_login_method(user_id, token_id, redis)
|
||||
"mail"
|
||||
if api_version < SUPPORT_TOTP_VERIFICATION_VER
|
||||
else await LoginSessionService.get_login_method(user_id, token_id, redis)
|
||||
)
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
login_method = "password"
|
||||
|
||||
try:
|
||||
@@ -130,10 +134,11 @@ async def verify_session(
|
||||
user_id=user_id,
|
||||
request=request,
|
||||
login_method=login_method,
|
||||
user_agent=user_agent.raw_ua,
|
||||
login_success=True,
|
||||
notes=f"{login_method} 验证成功",
|
||||
)
|
||||
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id)
|
||||
await LoginSessionService.mark_session_verified(db, redis, user_id, token_id, ip_address, user_agent, web_uuid)
|
||||
await db.commit()
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -179,6 +184,7 @@ async def verify_session(
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
user_agent: UserAgentInfo,
|
||||
api_version: APIVersion,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
@@ -198,7 +204,6 @@ async def reissue_verification_code(
|
||||
|
||||
try:
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
user_id = current_user.id
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db,
|
||||
@@ -227,6 +232,7 @@ async def reissue_verification_code(
|
||||
)
|
||||
async def fallback_email(
|
||||
db: Database,
|
||||
user_agent: UserAgentInfo,
|
||||
request: Request,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
user_and_token: UserAndToken = Security(get_client_user_and_token),
|
||||
@@ -237,7 +243,6 @@ async def fallback_email(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="当前会话不需要回退")
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
await LoginSessionService.set_login_method(current_user.id, token_id, "mail", redis)
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
数据库清理调度器 - 定时清理过期数据
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.dependencies.database import engine
|
||||
from app.log import logger
|
||||
from app.service.database_cleanup_service import DatabaseCleanupService
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class DatabaseCleanupScheduler:
|
||||
"""数据库清理调度器"""
|
||||
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self.task = None
|
||||
|
||||
async def start(self):
|
||||
"""启动调度器"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._run_scheduler())
|
||||
logger.debug("Database cleanup scheduler started")
|
||||
|
||||
async def stop(self):
|
||||
"""停止调度器"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
try:
|
||||
await self.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.debug("Database cleanup scheduler stopped")
|
||||
|
||||
async def _run_scheduler(self):
|
||||
"""运行调度器"""
|
||||
while self.running:
|
||||
try:
|
||||
# 每小时运行一次清理
|
||||
await asyncio.sleep(3600) # 3600秒 = 1小时
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
await self._run_cleanup()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Database cleanup scheduler error: {e!s}")
|
||||
# 发生错误后等待5分钟再继续
|
||||
await asyncio.sleep(300)
|
||||
|
||||
async def _run_cleanup(self):
|
||||
"""执行清理任务"""
|
||||
try:
|
||||
async with AsyncSession(engine) as db:
|
||||
logger.debug("Starting scheduled database cleanup...")
|
||||
|
||||
# 清理过期的验证码
|
||||
expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||
|
||||
# 清理过期的登录会话
|
||||
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
# 清理1小时前未验证的登录会话
|
||||
unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
|
||||
|
||||
# 只在有清理记录时输出总结
|
||||
total_cleaned = expired_codes + expired_sessions + unverified_sessions
|
||||
if total_cleaned > 0:
|
||||
logger.debug(
|
||||
f"Scheduled cleanup completed - codes: {expired_codes}, "
|
||||
f"sessions: {expired_sessions}, unverified: {unverified_sessions}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during scheduled database cleanup: {e!s}")
|
||||
|
||||
async def run_manual_cleanup(self):
|
||||
"""手动运行完整清理"""
|
||||
try:
|
||||
async with AsyncSession(engine) as db:
|
||||
logger.debug("Starting manual database cleanup...")
|
||||
results = await DatabaseCleanupService.run_full_cleanup(db)
|
||||
total = sum(results.values())
|
||||
if total > 0:
|
||||
logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Error during manual database cleanup: {e!s}")
|
||||
return {}
|
||||
|
||||
|
||||
# 全局实例
|
||||
database_cleanup_scheduler = DatabaseCleanupScheduler()
|
||||
|
||||
|
||||
async def start_database_cleanup_scheduler():
|
||||
"""启动数据库清理调度器"""
|
||||
await database_cleanup_scheduler.start()
|
||||
|
||||
|
||||
async def stop_database_cleanup_scheduler():
|
||||
"""停止数据库清理调度器"""
|
||||
await database_cleanup_scheduler.stop()
|
||||
|
||||
|
||||
async def run_manual_database_cleanup():
|
||||
"""手动运行数据库清理"""
|
||||
return await database_cleanup_scheduler.run_manual_cleanup()
|
||||
@@ -1,230 +0,0 @@
|
||||
"""
|
||||
客户端检测服务
|
||||
用于识别不同类型的 osu! 客户端和设备
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
import re
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientInfo:
|
||||
"""客户端信息"""
|
||||
|
||||
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"]
|
||||
platform: str | None = None
|
||||
version: str | None = None
|
||||
device_fingerprint: str | None = None
|
||||
is_trusted_client: bool = False
|
||||
|
||||
|
||||
class ClientDetectionService:
|
||||
"""客户端检测服务"""
|
||||
|
||||
# osu! 客户端的 User-Agent 模式
|
||||
OSU_CLIENT_PATTERNS: ClassVar[dict[str, list[str]]] = {
|
||||
"osu_stable": [
|
||||
r"osu!/(\d+(?:\.\d+)*)", # osu!/20241001
|
||||
r"osu!", # 简单匹配
|
||||
],
|
||||
"osu_lazer": [
|
||||
r"osu-lazer/(\d+(?:\.\d+)*)", # osu-lazer/2024.1009.0
|
||||
r"osu!lazer/(\d+(?:\.\d+)*)", # osu!lazer/2024.1009.0
|
||||
],
|
||||
"osu_web": [
|
||||
r"Mozilla.*osu\.ppy\.sh", # 网页客户端
|
||||
],
|
||||
"mobile": [
|
||||
r"osu!.*mobile",
|
||||
r"osu.*Mobile",
|
||||
r"Mobile.*osu",
|
||||
],
|
||||
}
|
||||
|
||||
# 受信任的客户端类型(不需要频繁验证)
|
||||
TRUSTED_CLIENT_TYPES: ClassVar[set[str]] = {"osu_stable", "osu_lazer"}
|
||||
|
||||
@staticmethod
|
||||
def detect_client(user_agent: str | None, client_id: int | None = None) -> ClientInfo:
|
||||
"""
|
||||
检测客户端类型和信息
|
||||
|
||||
Args:
|
||||
user_agent: 用户代理字符串
|
||||
client_id: OAuth 客户端 ID
|
||||
|
||||
Returns:
|
||||
ClientInfo: 客户端信息
|
||||
"""
|
||||
from app.config import settings # 导入在函数内部避免循环导入
|
||||
|
||||
if not user_agent:
|
||||
return ClientInfo(client_type="unknown")
|
||||
|
||||
# 优先通过 client_id 判断客户端类型
|
||||
if client_id is not None:
|
||||
if client_id == settings.osu_client_id:
|
||||
# osu! stable 客户端
|
||||
return ClientInfo(
|
||||
client_type="osu_stable",
|
||||
platform=ClientDetectionService._extract_platform(user_agent),
|
||||
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
|
||||
is_trusted_client=True,
|
||||
)
|
||||
elif client_id == settings.osu_web_client_id:
|
||||
# 检查 User-Agent 是否表明这是 Lazer 客户端
|
||||
if user_agent and user_agent.strip() == "osu!":
|
||||
# Lazer 客户端使用 web client_id 但发送简单的 "osu!" User-Agent
|
||||
return ClientInfo(
|
||||
client_type="osu_lazer",
|
||||
platform=ClientDetectionService._extract_platform(user_agent),
|
||||
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
|
||||
is_trusted_client=True,
|
||||
)
|
||||
else:
|
||||
# 真正的 web 客户端
|
||||
return ClientInfo(
|
||||
client_type="osu_web",
|
||||
platform=ClientDetectionService._extract_platform(user_agent),
|
||||
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
|
||||
is_trusted_client=False,
|
||||
)
|
||||
|
||||
# 回退到基于 User-Agent 的检测
|
||||
for client_type_str, patterns in ClientDetectionService.OSU_CLIENT_PATTERNS.items():
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, user_agent, re.IGNORECASE)
|
||||
if match:
|
||||
version = match.group(1) if match.groups() else None
|
||||
platform = ClientDetectionService._extract_platform(user_agent)
|
||||
|
||||
# 确保 client_type 是正确的 Literal 类型
|
||||
client_type: Literal["osu_stable", "osu_lazer", "osu_web", "mobile", "unknown"] = client_type_str # type: ignore
|
||||
|
||||
return ClientInfo(
|
||||
client_type=client_type,
|
||||
platform=platform,
|
||||
version=version,
|
||||
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
|
||||
is_trusted_client=client_type in ClientDetectionService.TRUSTED_CLIENT_TYPES,
|
||||
)
|
||||
|
||||
# 检测常见浏览器
|
||||
if any(browser in user_agent.lower() for browser in ["chrome", "firefox", "safari", "edge"]):
|
||||
return ClientInfo(
|
||||
client_type="osu_web",
|
||||
platform=ClientDetectionService._extract_platform(user_agent),
|
||||
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
|
||||
is_trusted_client=False,
|
||||
)
|
||||
|
||||
return ClientInfo(
|
||||
client_type="unknown",
|
||||
device_fingerprint=ClientDetectionService._generate_device_fingerprint(user_agent),
|
||||
is_trusted_client=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_platform(user_agent: str) -> str | None:
|
||||
"""从 User-Agent 中提取平台信息"""
|
||||
platforms = {
|
||||
"windows": ["windows", "win32", "win64"],
|
||||
"macos": ["macintosh", "mac os", "darwin"],
|
||||
"linux": ["linux", "ubuntu", "debian"],
|
||||
"android": ["android"],
|
||||
"ios": ["iphone", "ipad", "ios"],
|
||||
}
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
for platform, keywords in platforms.items():
|
||||
if any(keyword in user_agent_lower for keyword in keywords):
|
||||
return platform
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _generate_device_fingerprint(user_agent: str) -> str:
|
||||
"""生成设备指纹"""
|
||||
# 使用 User-Agent 的哈希值作为简单的设备指纹
|
||||
# 在实际应用中可以结合更多信息(IP、屏幕分辨率等)
|
||||
return hashlib.sha256(user_agent.encode()).hexdigest()[:16]
|
||||
|
||||
@staticmethod
|
||||
def should_skip_email_verification(
|
||||
client_info: ClientInfo,
|
||||
is_new_location: bool,
|
||||
user_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
判断是否应该跳过邮件验证
|
||||
|
||||
Args:
|
||||
client_info: 客户端信息
|
||||
is_new_location: 是否为新位置登录
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
bool: 是否应该跳过邮件验证
|
||||
"""
|
||||
# 受信任的客户端类型可以减少验证频率
|
||||
if client_info.is_trusted_client:
|
||||
logger.info(
|
||||
f"[Client Detection] Trusted client {client_info.client_type} for user {user_id}, "
|
||||
f"reducing verification requirements"
|
||||
)
|
||||
return True
|
||||
|
||||
# 如果不是新位置,跳过验证
|
||||
if not is_new_location:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_verification_cooldown(client_info: ClientInfo) -> int:
|
||||
"""
|
||||
获取验证冷却时间(秒)
|
||||
|
||||
Args:
|
||||
client_info: 客户端信息
|
||||
|
||||
Returns:
|
||||
int: 冷却时间(秒)
|
||||
"""
|
||||
# 受信任的客户端有更长的冷却时间
|
||||
if client_info.is_trusted_client:
|
||||
return 3600 # 1小时
|
||||
|
||||
# 网页客户端较短的冷却时间
|
||||
if client_info.client_type == "osu_web":
|
||||
return 1800 # 30分钟
|
||||
|
||||
# 未知客户端最短冷却时间
|
||||
return 900 # 15分钟
|
||||
|
||||
@staticmethod
|
||||
def format_client_display_name(client_info: ClientInfo) -> str:
|
||||
"""格式化客户端显示名称"""
|
||||
display_names = {
|
||||
"osu_stable": "osu! (stable)",
|
||||
"osu_lazer": "osu!(lazer)",
|
||||
"osu_web": "osu! web",
|
||||
"mobile": "osu! mobile",
|
||||
"unknown": "Unknown client",
|
||||
}
|
||||
|
||||
base_name = display_names.get(client_info.client_type, "Unknown client")
|
||||
|
||||
if client_info.version:
|
||||
base_name += f" v{client_info.version}"
|
||||
|
||||
if client_info.platform:
|
||||
base_name += f" ({client_info.platform})"
|
||||
|
||||
return base_name
|
||||
@@ -6,11 +6,14 @@ from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.database.auth import OAuthToken
|
||||
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
|
||||
from app.dependencies.database import with_db
|
||||
from app.dependencies.scheduler import get_scheduler
|
||||
from app.log import logger
|
||||
from app.utils import utcnow
|
||||
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import col, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -69,7 +72,9 @@ class DatabaseCleanupService:
|
||||
# 查找过期的登录会话记录
|
||||
current_time = utcnow()
|
||||
|
||||
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time, col(LoginSession.is_verified).is_(False)
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
expired_sessions = result.all()
|
||||
|
||||
@@ -179,50 +184,109 @@ class DatabaseCleanupService:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||
async def cleanup_outdated_verified_sessions(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理旧的已验证会话记录
|
||||
清理过期会话记录
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已验证记录,默认30天
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已验证会话记录
|
||||
cutoff_time = utcnow() - timedelta(days=days_old)
|
||||
|
||||
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
stmt = select(LoginSession).where(
|
||||
col(LoginSession.is_verified).is_(True), col(LoginSession.token_id).is_(None)
|
||||
)
|
||||
result = await db.exec(stmt)
|
||||
all_verified_sessions = result.all()
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_verified_sessions = [
|
||||
session
|
||||
for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_time
|
||||
]
|
||||
|
||||
# 删除旧的已验证记录
|
||||
deleted_count = 0
|
||||
for session in old_verified_sessions:
|
||||
for session in result.all():
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
|
||||
)
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} outdated verified sessions")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning outdated verified sessions: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_outdated_trusted_devices(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的受信任设备记录
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的受信任设备记录
|
||||
current_time = utcnow()
|
||||
|
||||
stmt = select(TrustedDevice).where(TrustedDevice.expires_at < current_time)
|
||||
result = await db.exec(stmt)
|
||||
expired_devices = result.all()
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for device in expired_devices:
|
||||
await db.delete(device)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired trusted devices")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired trusted devices: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_outdated_tokens(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的 OAuth 令牌
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
current_time = utcnow()
|
||||
|
||||
stmt = select(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
|
||||
result = await db.exec(stmt)
|
||||
expired_tokens = result.all()
|
||||
|
||||
deleted_count = 0
|
||||
for token in expired_tokens:
|
||||
await db.delete(token)
|
||||
deleted_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired OAuth tokens")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired OAuth tokens: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
@@ -250,8 +314,14 @@ class DatabaseCleanupService:
|
||||
# 清理7天前的已使用验证码
|
||||
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||
|
||||
# 清理30天前的已验证会话
|
||||
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
|
||||
# 清理过期的受信任设备
|
||||
results["outdated_trusted_devices"] = await DatabaseCleanupService.cleanup_outdated_trusted_devices(db)
|
||||
|
||||
# 清理过期的 OAuth 令牌
|
||||
results["outdated_oauth_tokens"] = await DatabaseCleanupService.cleanup_outdated_tokens(db)
|
||||
|
||||
# 清理过期(token 过期)的已验证会话
|
||||
results["outdated_verified_sessions"] = await DatabaseCleanupService.cleanup_outdated_verified_sessions(db)
|
||||
|
||||
total_cleaned = sum(results.values())
|
||||
if total_cleaned > 0:
|
||||
@@ -279,21 +349,27 @@ class DatabaseCleanupService:
|
||||
cutoff_30_days = current_time - timedelta(days=30)
|
||||
|
||||
# 统计过期的验证码数量
|
||||
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
|
||||
expired_codes_stmt = (
|
||||
select(func.count()).select_from(EmailVerification).where(EmailVerification.expires_at < current_time)
|
||||
)
|
||||
expired_codes_result = await db.exec(expired_codes_stmt)
|
||||
expired_codes_count = len(expired_codes_result.all())
|
||||
expired_codes_count = expired_codes_result.one()
|
||||
|
||||
# 统计过期的登录会话数量
|
||||
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
expired_sessions_stmt = (
|
||||
select(func.count()).select_from(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
)
|
||||
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||
expired_sessions_count = len(expired_sessions_result.all())
|
||||
expired_sessions_count = expired_sessions_result.one()
|
||||
|
||||
# 统计1小时前未验证的登录会话数量
|
||||
unverified_sessions_stmt = select(LoginSession).where(
|
||||
col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour
|
||||
unverified_sessions_stmt = (
|
||||
select(func.count())
|
||||
.select_from(LoginSession)
|
||||
.where(col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour)
|
||||
)
|
||||
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
|
||||
unverified_sessions_count = len(unverified_sessions_result.all())
|
||||
unverified_sessions_count = unverified_sessions_result.one()
|
||||
|
||||
# 统计7天前的已使用验证码数量
|
||||
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
@@ -304,10 +380,10 @@ class DatabaseCleanupService:
|
||||
)
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||
all_verified_sessions = old_verified_sessions_result.all()
|
||||
old_verified_sessions_count = len(
|
||||
outdated_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
outdated_verified_sessions_result = await db.exec(outdated_verified_sessions_stmt)
|
||||
all_verified_sessions = outdated_verified_sessions_result.all()
|
||||
outdated_verified_sessions_count = len(
|
||||
[
|
||||
session
|
||||
for session in all_verified_sessions
|
||||
@@ -315,17 +391,35 @@ class DatabaseCleanupService:
|
||||
]
|
||||
)
|
||||
|
||||
# 统计过期的 OAuth 令牌数量
|
||||
outdated_tokens_stmt = (
|
||||
select(func.count()).select_from(OAuthToken).where(OAuthToken.refresh_token_expires_at < current_time)
|
||||
)
|
||||
outdated_tokens_result = await db.exec(outdated_tokens_stmt)
|
||||
outdated_tokens_count = outdated_tokens_result.one()
|
||||
|
||||
# 统计过期的受信任设备数量
|
||||
outdated_devices_stmt = (
|
||||
select(func.count()).select_from(TrustedDevice).where(TrustedDevice.expires_at < current_time)
|
||||
)
|
||||
outdated_devices_result = await db.exec(outdated_devices_stmt)
|
||||
outdated_devices_count = outdated_devices_result.one()
|
||||
|
||||
return {
|
||||
"expired_verification_codes": expired_codes_count,
|
||||
"expired_login_sessions": expired_sessions_count,
|
||||
"unverified_login_sessions": unverified_sessions_count,
|
||||
"old_used_verification_codes": old_used_codes_count,
|
||||
"old_verified_sessions": old_verified_sessions_count,
|
||||
"outdated_verified_sessions": outdated_verified_sessions_count,
|
||||
"outdated_oauth_tokens": outdated_tokens_count,
|
||||
"outdated_trusted_devices": outdated_devices_count,
|
||||
"total_cleanable": expired_codes_count
|
||||
+ expired_sessions_count
|
||||
+ unverified_sessions_count
|
||||
+ old_used_codes_count
|
||||
+ old_verified_sessions_count,
|
||||
+ outdated_verified_sessions_count
|
||||
+ outdated_tokens_count
|
||||
+ outdated_devices_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -335,6 +429,23 @@ class DatabaseCleanupService:
|
||||
"expired_login_sessions": 0,
|
||||
"unverified_login_sessions": 0,
|
||||
"old_used_verification_codes": 0,
|
||||
"old_verified_sessions": 0,
|
||||
"outdated_verified_sessions": 0,
|
||||
"outdated_oauth_tokens": 0,
|
||||
"outdated_trusted_devices": 0,
|
||||
"total_cleanable": 0,
|
||||
}
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job(
|
||||
"interval",
|
||||
id="cleanup_database",
|
||||
hours=1,
|
||||
)
|
||||
async def scheduled_cleanup_job():
|
||||
async with with_db() as session:
|
||||
logger.debug("Starting database cleanup...")
|
||||
results = await DatabaseCleanupService.run_full_cleanup(session)
|
||||
total = sum(results.values())
|
||||
if total > 0:
|
||||
logger.debug(f"Cleanup completed, total records cleaned: {total}")
|
||||
return results
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
设备信任服务
|
||||
管理用户的受信任设备,减少频繁验证
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
from app.service.client_detection_service import ClientInfo
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class DeviceTrustService:
|
||||
"""设备信任服务"""
|
||||
|
||||
@staticmethod
|
||||
def _get_device_trust_key(user_id: int, device_fingerprint: str) -> str:
|
||||
"""获取设备信任的 Redis 键"""
|
||||
return f"device_trust:{user_id}:{device_fingerprint}"
|
||||
|
||||
@staticmethod
|
||||
def _get_location_trust_key(user_id: int, country_code: str) -> str:
|
||||
"""获取位置信任的 Redis 键"""
|
||||
return f"location_trust:{user_id}:{country_code}"
|
||||
|
||||
@staticmethod
|
||||
def _get_verification_cooldown_key(user_id: int) -> str:
|
||||
"""获取验证冷却的 Redis 键"""
|
||||
return f"verification_cooldown:{user_id}"
|
||||
|
||||
@staticmethod
|
||||
async def is_device_trusted(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
device_fingerprint: str,
|
||||
) -> bool:
|
||||
"""
|
||||
检查设备是否受信任
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
device_fingerprint: 设备指纹
|
||||
|
||||
Returns:
|
||||
bool: 设备是否受信任
|
||||
"""
|
||||
if not device_fingerprint:
|
||||
return False
|
||||
|
||||
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
|
||||
trust_data = await redis.get(trust_key)
|
||||
|
||||
return trust_data is not None
|
||||
|
||||
@staticmethod
|
||||
async def is_location_trusted(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
country_code: str | None,
|
||||
) -> bool:
|
||||
"""
|
||||
检查位置是否受信任
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
country_code: 国家代码
|
||||
|
||||
Returns:
|
||||
bool: 位置是否受信任
|
||||
"""
|
||||
if not country_code:
|
||||
return False
|
||||
|
||||
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
|
||||
trust_data = await redis.get(trust_key)
|
||||
|
||||
return trust_data is not None
|
||||
|
||||
@staticmethod
|
||||
async def is_in_verification_cooldown(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
检查用户是否在验证冷却期内
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
bool: 是否在冷却期内
|
||||
"""
|
||||
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
|
||||
cooldown_data = await redis.get(cooldown_key)
|
||||
|
||||
return cooldown_data is not None
|
||||
|
||||
@staticmethod
|
||||
async def trust_device(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
device_fingerprint: str,
|
||||
client_info: ClientInfo,
|
||||
trust_duration_days: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
信任设备
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
device_fingerprint: 设备指纹
|
||||
client_info: 客户端信息
|
||||
trust_duration_days: 信任持续天数
|
||||
"""
|
||||
if not device_fingerprint:
|
||||
return
|
||||
|
||||
# 使用配置中的默认值
|
||||
if trust_duration_days is None:
|
||||
trust_duration_days = settings.device_trust_duration_days
|
||||
|
||||
trust_key = DeviceTrustService._get_device_trust_key(user_id, device_fingerprint)
|
||||
trust_data = {
|
||||
"client_type": client_info.client_type,
|
||||
"platform": client_info.platform or "unknown",
|
||||
"trusted_at": utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# 设置信任期限
|
||||
trust_duration_seconds = trust_duration_days * 24 * 3600
|
||||
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
|
||||
|
||||
logger.info(
|
||||
f"[Device Trust] Device trusted for user {user_id}: "
|
||||
f"{client_info.client_type} on {client_info.platform} "
|
||||
f"(fingerprint: {device_fingerprint[:8]}...)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def trust_location(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
country_code: str,
|
||||
trust_duration_days: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
信任位置
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
country_code: 国家代码
|
||||
trust_duration_days: 信任持续天数
|
||||
"""
|
||||
if not country_code:
|
||||
return
|
||||
|
||||
# 使用配置中的默认值
|
||||
if trust_duration_days is None:
|
||||
trust_duration_days = settings.location_trust_duration_days
|
||||
|
||||
trust_key = DeviceTrustService._get_location_trust_key(user_id, country_code)
|
||||
trust_data = {
|
||||
"country_code": country_code,
|
||||
"trusted_at": utcnow().isoformat(),
|
||||
}
|
||||
|
||||
# 设置信任期限
|
||||
trust_duration_seconds = trust_duration_days * 24 * 3600
|
||||
await redis.setex(trust_key, trust_duration_seconds, str(trust_data))
|
||||
|
||||
logger.info(f"[Location Trust] Location trusted for user {user_id}: {country_code}")
|
||||
|
||||
@staticmethod
|
||||
async def set_verification_cooldown(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
cooldown_seconds: int,
|
||||
) -> None:
|
||||
"""
|
||||
设置验证冷却期
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
cooldown_seconds: 冷却时间(秒)
|
||||
"""
|
||||
cooldown_key = DeviceTrustService._get_verification_cooldown_key(user_id)
|
||||
cooldown_data = {
|
||||
"set_at": utcnow().isoformat(),
|
||||
"expires_at": (utcnow() + timedelta(seconds=cooldown_seconds)).isoformat(),
|
||||
}
|
||||
|
||||
await redis.setex(cooldown_key, cooldown_seconds, str(cooldown_data))
|
||||
|
||||
logger.info(f"[Verification Cooldown] Set cooldown for user {user_id}: {cooldown_seconds}s")
|
||||
|
||||
@staticmethod
|
||||
async def should_require_verification(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
device_fingerprint: str | None,
|
||||
country_code: str | None,
|
||||
client_info: ClientInfo,
|
||||
is_new_location: bool,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
判断是否需要验证
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
device_fingerprint: 设备指纹
|
||||
country_code: 国家代码
|
||||
client_info: 客户端信息
|
||||
is_new_location: 是否为新位置
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否需要验证, 原因)
|
||||
"""
|
||||
# 检查验证冷却期
|
||||
if await DeviceTrustService.is_in_verification_cooldown(redis, user_id):
|
||||
return False, "用户在验证冷却期内"
|
||||
|
||||
# 检查设备信任
|
||||
if device_fingerprint and await DeviceTrustService.is_device_trusted(redis, user_id, device_fingerprint):
|
||||
return False, "设备已受信任"
|
||||
|
||||
# 检查位置信任
|
||||
if country_code and await DeviceTrustService.is_location_trusted(redis, user_id, country_code):
|
||||
return False, "位置已受信任"
|
||||
|
||||
# 受信任的客户端类型降低验证要求
|
||||
if client_info.is_trusted_client and not is_new_location:
|
||||
return False, "受信任客户端且非新位置"
|
||||
|
||||
# 如果是新位置登录,需要验证
|
||||
if is_new_location:
|
||||
return True, "新位置登录需要验证"
|
||||
|
||||
# 默认不需要验证
|
||||
return False, "常规登录无需验证"
|
||||
|
||||
@staticmethod
|
||||
async def mark_verification_successful(
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
device_fingerprint: str | None,
|
||||
country_code: str | None,
|
||||
client_info: ClientInfo,
|
||||
) -> None:
|
||||
"""
|
||||
标记验证成功,更新信任信息
|
||||
|
||||
Args:
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
device_fingerprint: 设备指纹
|
||||
country_code: 国家代码
|
||||
client_info: 客户端信息
|
||||
"""
|
||||
# 信任设备
|
||||
if device_fingerprint:
|
||||
await DeviceTrustService.trust_device(redis, user_id, device_fingerprint, client_info)
|
||||
|
||||
# 信任位置
|
||||
if country_code:
|
||||
await DeviceTrustService.trust_location(redis, user_id, country_code)
|
||||
|
||||
# 设置验证冷却期
|
||||
cooldown_seconds = (client_info.is_trusted_client and 3600) or 1800 # 受信任客户端1小时,其他30分钟
|
||||
await DeviceTrustService.set_verification_cooldown(redis, user_id, cooldown_seconds)
|
||||
|
||||
logger.info(f"[Device Trust] Verification successful for user {user_id}, trust updated")
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
from app.database.user_login_log import UserLoginLog
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
|
||||
from app.log import logger
|
||||
from app.utils import simplify_user_agent, utcnow
|
||||
from app.utils import utcnow
|
||||
|
||||
from fastapi import Request
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -23,6 +23,7 @@ class LoginLogService:
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
request: Request,
|
||||
user_agent: str | None = None,
|
||||
login_success: bool = True,
|
||||
login_method: str = "password",
|
||||
notes: str | None = None,
|
||||
@@ -45,9 +46,6 @@ class LoginLogService:
|
||||
raw_ip = get_client_ip(request)
|
||||
ip_address = normalize_ip(raw_ip)
|
||||
|
||||
raw_user_agent = request.headers.get("User-Agent", "")
|
||||
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
|
||||
|
||||
# 创建基本的登录记录
|
||||
login_log = UserLoginLog(
|
||||
user_id=user_id,
|
||||
@@ -107,6 +105,7 @@ class LoginLogService:
|
||||
attempted_username: str | None = None,
|
||||
login_method: str = "password",
|
||||
notes: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> UserLoginLog:
|
||||
"""
|
||||
记录失败的登录尝试
|
||||
@@ -128,6 +127,7 @@ class LoginLogService:
|
||||
request=request,
|
||||
login_success=False,
|
||||
login_method=login_method,
|
||||
user_agent=user_agent,
|
||||
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
API 状态管理 - 模拟 osu! 的 APIState 和会话管理
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIState(str, Enum):
|
||||
"""API 连接状态,对应 osu! 的 APIState"""
|
||||
|
||||
OFFLINE = "offline"
|
||||
CONNECTING = "connecting"
|
||||
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
|
||||
ONLINE = "online"
|
||||
FAILING = "failing"
|
||||
|
||||
|
||||
class UserSession(BaseModel):
|
||||
"""用户会话信息"""
|
||||
|
||||
user_id: int
|
||||
username: str
|
||||
email: str
|
||||
session_token: str | None = None
|
||||
state: APIState = APIState.OFFLINE
|
||||
requires_verification: bool = False
|
||||
verification_sent: bool = False
|
||||
last_verification_attempt: datetime | None = None
|
||||
failed_attempts: int = 0
|
||||
ip_address: str | None = None
|
||||
country_code: str | None = None
|
||||
is_new_location: bool = False
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, UserSession] = {}
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
) -> UserSession:
|
||||
"""创建新的用户会话"""
|
||||
import secrets
|
||||
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
|
||||
# 根据是否为新位置决定初始状态
|
||||
if is_new_location:
|
||||
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
|
||||
else:
|
||||
state = APIState.ONLINE
|
||||
|
||||
session = UserSession(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
email=email,
|
||||
session_token=session_token,
|
||||
state=state,
|
||||
requires_verification=is_new_location,
|
||||
ip_address=ip_address,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
)
|
||||
|
||||
self._sessions[session_token] = session
|
||||
return session
|
||||
|
||||
def get_session(self, session_token: str) -> UserSession | None:
|
||||
"""获取会话"""
|
||||
return self._sessions.get(session_token)
|
||||
|
||||
def update_session_state(self, session_token: str, state: APIState):
|
||||
"""更新会话状态"""
|
||||
if session_token in self._sessions:
|
||||
self._sessions[session_token].state = state
|
||||
|
||||
def mark_verification_sent(self, session_token: str):
|
||||
"""标记验证邮件已发送"""
|
||||
if session_token in self._sessions:
|
||||
session = self._sessions[session_token]
|
||||
session.verification_sent = True
|
||||
session.last_verification_attempt = datetime.now()
|
||||
|
||||
def increment_failed_attempts(self, session_token: str):
|
||||
"""增加失败尝试次数"""
|
||||
if session_token in self._sessions:
|
||||
self._sessions[session_token].failed_attempts += 1
|
||||
|
||||
def verify_session(self, session_token: str) -> bool:
|
||||
"""验证会话成功"""
|
||||
if session_token in self._sessions:
|
||||
session = self._sessions[session_token]
|
||||
session.state = APIState.ONLINE
|
||||
session.requires_verification = False
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_session(self, session_token: str):
|
||||
"""移除会话"""
|
||||
self._sessions.pop(session_token, None)
|
||||
|
||||
def cleanup_expired_sessions(self):
|
||||
"""清理过期会话"""
|
||||
# 这里可以实现清理逻辑
|
||||
pass
|
||||
|
||||
|
||||
# 全局会话管理器
|
||||
session_manager = SessionManager()
|
||||
@@ -10,15 +10,15 @@ import string
|
||||
from typing import Literal
|
||||
|
||||
from app.config import settings
|
||||
from app.database.verification import EmailVerification, LoginSession
|
||||
from app.database.auth import OAuthToken
|
||||
from app.database.verification import EmailVerification, LoginSession, TrustedDevice
|
||||
from app.log import logger
|
||||
from app.service.client_detection_service import ClientDetectionService, ClientInfo
|
||||
from app.service.device_trust_service import DeviceTrustService
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.models.model import UserAgentInfo
|
||||
from app.service.email_queue import email_queue
|
||||
from app.utils import utcnow
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import exists, select
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -248,11 +248,9 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
client_id: int | None = None,
|
||||
country_code: str | None = None,
|
||||
user_agent: UserAgentInfo | None = None,
|
||||
) -> bool:
|
||||
"""发送验证邮件(带智能检测)"""
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
@@ -260,32 +258,14 @@ This email was sent automatically, please do not reply.
|
||||
return True # 返回成功,但不执行验证流程
|
||||
|
||||
# 检测客户端信息
|
||||
client_info = ClientDetectionService.detect_client(user_agent, client_id)
|
||||
logger.info(
|
||||
f"[Email Verification] Detected client for user {user_id}: "
|
||||
f"{ClientDetectionService.format_client_display_name(client_info)}"
|
||||
)
|
||||
|
||||
# 检查是否需要验证
|
||||
needs_verification, reason = await DeviceTrustService.should_require_verification(
|
||||
redis=redis,
|
||||
user_id=user_id,
|
||||
device_fingerprint=client_info.device_fingerprint,
|
||||
country_code=country_code,
|
||||
client_info=client_info,
|
||||
is_new_location=True, # 这里需要从调用方传入
|
||||
)
|
||||
|
||||
if not needs_verification:
|
||||
logger.info(f"[Email Verification] Skipping verification for user {user_id}: {reason}")
|
||||
return True
|
||||
logger.info(f"[Email Verification] Detected client for user {user_id}: {user_agent}")
|
||||
|
||||
# 创建验证记录
|
||||
(
|
||||
_,
|
||||
code,
|
||||
) = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
db, redis, user_id, email, ip_address, user_agent.raw_ua if user_agent else None
|
||||
)
|
||||
|
||||
# 使用邮件队列发送验证邮件
|
||||
@@ -304,107 +284,6 @@ This email was sent automatically, please do not reply.
|
||||
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def send_smart_verification_email(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
client_id: int | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
) -> tuple[bool, str, ClientInfo | None]:
|
||||
"""
|
||||
智能邮件验证发送
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
redis: Redis 连接
|
||||
user_id: 用户 ID
|
||||
username: 用户名
|
||||
email: 邮箱地址
|
||||
ip_address: IP 地址
|
||||
user_agent: 用户代理
|
||||
client_id: 客户端 ID
|
||||
country_code: 国家代码
|
||||
is_new_location: 是否为新位置登录
|
||||
|
||||
Returns:
|
||||
tuple[bool, str, ClientInfo | None]: (是否成功, 消息, 客户端信息)
|
||||
"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Smart Verification] Email verification is disabled, skipping for user {user_id}")
|
||||
return True, "邮件验证功能已禁用", None
|
||||
|
||||
# 检查是否启用智能验证
|
||||
if not settings.enable_smart_verification:
|
||||
logger.debug(
|
||||
f"[Smart Verification] Smart verification is disabled, using legacy logic for user {user_id}"
|
||||
)
|
||||
# 回退到传统验证逻辑
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
success = await EmailVerificationService.send_verification_email_via_queue(
|
||||
email, code, username, user_id
|
||||
)
|
||||
return success, "使用传统验证逻辑发送邮件" if success else "传统验证邮件发送失败", None
|
||||
|
||||
# 检测客户端信息
|
||||
client_info = ClientDetectionService.detect_client(user_agent, client_id)
|
||||
client_display_name = ClientDetectionService.format_client_display_name(client_info)
|
||||
|
||||
logger.info(f"[Smart Verification] Detected client for user {user_id}: {client_display_name}")
|
||||
|
||||
# 检查是否需要验证
|
||||
needs_verification, reason = await DeviceTrustService.should_require_verification(
|
||||
redis=redis,
|
||||
user_id=user_id,
|
||||
device_fingerprint=client_info.device_fingerprint,
|
||||
country_code=country_code,
|
||||
client_info=client_info,
|
||||
is_new_location=is_new_location,
|
||||
)
|
||||
|
||||
if not needs_verification:
|
||||
logger.info(f"[Smart Verification] Skipping verification for user {user_id}: {reason}")
|
||||
|
||||
# 即使不需要验证,也要更新设备信任信息
|
||||
if client_info.device_fingerprint:
|
||||
await DeviceTrustService.trust_device(redis, user_id, client_info.device_fingerprint, client_info)
|
||||
if country_code:
|
||||
await DeviceTrustService.trust_location(redis, user_id, country_code)
|
||||
|
||||
return True, f"跳过验证: {reason}", client_info
|
||||
|
||||
# 创建验证记录
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
_ = verification # 避免未使用变量警告
|
||||
|
||||
# 使用邮件队列发送验证邮件
|
||||
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
f"[Smart Verification] Successfully sent verification email to {email} "
|
||||
f"for user {username} using {client_display_name}"
|
||||
)
|
||||
return True, "验证邮件已发送", client_info
|
||||
else:
|
||||
logger.error(f"[Smart Verification] Failed to send verification email: {email} (user: {username})")
|
||||
return False, "验证邮件发送失败", client_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Smart Verification] Exception during smart verification: {e}")
|
||||
return False, f"验证过程中发生错误: {e!s}", None
|
||||
|
||||
@staticmethod
|
||||
async def verify_email_code(
|
||||
db: AsyncSession,
|
||||
@@ -416,7 +295,7 @@ This email was sent automatically, please do not reply.
|
||||
client_id: int | None = None,
|
||||
country_code: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""验证邮箱验证码(带智能信任更新)"""
|
||||
"""验证邮箱验证码"""
|
||||
try:
|
||||
# 检查是否启用邮件验证功能
|
||||
if not settings.enable_email_verification:
|
||||
@@ -452,16 +331,6 @@ This email was sent automatically, please do not reply.
|
||||
# 删除 Redis 记录
|
||||
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||
|
||||
# 检测客户端信息并更新信任状态
|
||||
client_info = ClientDetectionService.detect_client(user_agent, client_id)
|
||||
await DeviceTrustService.mark_verification_successful(
|
||||
redis=redis,
|
||||
user_id=user_id,
|
||||
device_fingerprint=client_info.device_fingerprint,
|
||||
country_code=country_code,
|
||||
client_info=client_info,
|
||||
)
|
||||
|
||||
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||
return True, "验证成功"
|
||||
|
||||
@@ -477,7 +346,7 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
user_agent: UserAgentInfo | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""重新发送验证码"""
|
||||
try:
|
||||
@@ -516,12 +385,12 @@ class LoginSessionService:
|
||||
|
||||
# Session verification interface methods
|
||||
@staticmethod
|
||||
async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None:
|
||||
async def find_for_verification(db: AsyncSession, token: str) -> LoginSession | None:
|
||||
"""根据会话ID查找会话用于验证"""
|
||||
try:
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_id,
|
||||
col(LoginSession.token).has(col(OAuthToken.access_token) == token),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
)
|
||||
)
|
||||
@@ -537,42 +406,31 @@ class LoginSessionService:
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
token_id: int,
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False,
|
||||
is_new_device: bool = False,
|
||||
web_uuid: str | None = None,
|
||||
is_verified: bool = False,
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
|
||||
session_token = EmailVerificationService.generate_session_token()
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
token_id=token_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=None,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
user_agent=user_agent,
|
||||
is_new_device=is_new_device,
|
||||
expires_at=utcnow() + timedelta(hours=24), # 24小时过期
|
||||
is_verified=is_verified,
|
||||
web_uuid=web_uuid,
|
||||
)
|
||||
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
# 存储到 Redis
|
||||
await redis.setex(
|
||||
f"login_session:{session_token}",
|
||||
86400, # 24小时
|
||||
user_id,
|
||||
)
|
||||
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new device: {is_new_device})")
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
@@ -592,35 +450,98 @@ class LoginSessionService:
|
||||
await redis.delete(cls._session_verify_redis_key(user_id, token_id))
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
|
||||
async def check_trusted_device(
|
||||
db: AsyncSession, user_id: int, ip_address: str, user_agent: UserAgentInfo, web_uuid: str | None = None
|
||||
) -> bool:
|
||||
"""检查是否为新位置登录"""
|
||||
try:
|
||||
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||
thirty_days_ago = utcnow() - timedelta(days=30)
|
||||
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.created_at > thirty_days_ago,
|
||||
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
|
||||
)
|
||||
if user_agent.is_client:
|
||||
query = select(exists()).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "client",
|
||||
TrustedDevice.ip_address == ip_address,
|
||||
TrustedDevice.expires_at > utcnow(),
|
||||
)
|
||||
|
||||
existing_sessions = result.all()
|
||||
|
||||
# 如果有历史记录,则不是新位置
|
||||
return len(existing_sessions) == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||
# 出错时默认为新位置(更安全)
|
||||
return True
|
||||
else:
|
||||
if web_uuid is None:
|
||||
return False
|
||||
query = select(exists()).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "web",
|
||||
TrustedDevice.web_uuid == web_uuid,
|
||||
TrustedDevice.expires_at > utcnow(),
|
||||
)
|
||||
return (await db.exec(query)).first() or False
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(db: AsyncSession, redis: Redis, user_id: int, token_id: int) -> bool:
|
||||
async def create_trusted_device(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
user_agent: UserAgentInfo,
|
||||
web_uuid: str | None = None,
|
||||
) -> TrustedDevice:
|
||||
device = TrustedDevice(
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent.raw_ua,
|
||||
client_type="client" if user_agent.is_client else "web",
|
||||
web_uuid=web_uuid if not user_agent.is_client else None,
|
||||
expires_at=utcnow() + timedelta(days=settings.device_trust_duration_days),
|
||||
)
|
||||
db.add(device)
|
||||
await db.commit()
|
||||
await db.refresh(device)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_trusted_device(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
user_agent: UserAgentInfo,
|
||||
web_uuid: str | None = None,
|
||||
) -> TrustedDevice:
|
||||
if user_agent.is_client:
|
||||
query = select(TrustedDevice).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "client",
|
||||
TrustedDevice.ip_address == ip_address,
|
||||
)
|
||||
else:
|
||||
if web_uuid is None:
|
||||
raise ValueError("web_uuid is required for web clients")
|
||||
query = select(TrustedDevice).where(
|
||||
TrustedDevice.user_id == user_id,
|
||||
TrustedDevice.client_type == "web",
|
||||
TrustedDevice.web_uuid == web_uuid,
|
||||
)
|
||||
|
||||
device = (await db.exec(query)).first()
|
||||
if device is None:
|
||||
device = await LoginSessionService.create_trusted_device(db, user_id, ip_address, user_agent, web_uuid)
|
||||
else:
|
||||
device.last_used_at = utcnow()
|
||||
device.expires_at = utcnow() + timedelta(days=settings.device_trust_duration_days)
|
||||
await db.commit()
|
||||
await db.refresh(device)
|
||||
return device
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
token_id: int,
|
||||
ip_address: str,
|
||||
user_agent: UserAgentInfo,
|
||||
web_uuid: str | None = None,
|
||||
) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
device_info: TrustedDevice | None = None
|
||||
if user_agent.is_client or web_uuid:
|
||||
device_info = await LoginSessionService.get_or_create_trusted_device(
|
||||
db, user_id, ip_address, user_agent, web_uuid
|
||||
)
|
||||
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
result = await db.exec(
|
||||
@@ -631,18 +552,20 @@ class LoginSessionService:
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
)
|
||||
|
||||
sessions = result.all()
|
||||
|
||||
# 标记所有会话为已验证
|
||||
for session in sessions:
|
||||
session.is_verified = True
|
||||
session.verified_at = utcnow()
|
||||
if device_info:
|
||||
session.device_id = device_info.id
|
||||
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
await LoginSessionService.clear_login_method(user_id, token_id, redis)
|
||||
await db.commit()
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
@@ -658,7 +581,7 @@ class LoginSessionService:
|
||||
await db.exec(
|
||||
select(exists()).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False, # noqa: E712
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > utcnow(),
|
||||
LoginSession.token_id == token_id,
|
||||
)
|
||||
|
||||
128
app/utils.py
128
app/utils.py
@@ -6,11 +6,15 @@ from datetime import UTC, datetime
|
||||
import functools
|
||||
import inspect
|
||||
from io import BytesIO
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.model import UserAgentInfo
|
||||
|
||||
|
||||
def unix_timestamp_to_windows(timestamp: int) -> int:
|
||||
"""Convert a Unix timestamp to a Windows timestamp."""
|
||||
@@ -154,81 +158,79 @@ def check_image(content: bytes, size: int, width: int, height: int) -> str:
|
||||
raise HTTPException(status_code=400, detail=f"Error processing image: {e}")
|
||||
|
||||
|
||||
def simplify_user_agent(user_agent: str | None, max_length: int = 200) -> str | None:
|
||||
"""
|
||||
简化 User-Agent 字符串,只保留 osu! 和关键设备系统信息浏览器
|
||||
def extract_user_agent(user_agent: str | None) -> "UserAgentInfo":
|
||||
from app.models.model import UserAgentInfo
|
||||
|
||||
Args:
|
||||
user_agent: 原始 User-Agent 字符串
|
||||
max_length: 最大长度限制
|
||||
raw_ua = user_agent or ""
|
||||
ua = raw_ua.strip()
|
||||
lower_ua = ua.lower()
|
||||
|
||||
Returns:
|
||||
简化后的 User-Agent 字符串,或 None
|
||||
"""
|
||||
import re
|
||||
info = UserAgentInfo(raw_ua=raw_ua)
|
||||
|
||||
if not user_agent:
|
||||
return None
|
||||
if not ua:
|
||||
return info
|
||||
|
||||
# 如果长度在限制内,直接返回
|
||||
if len(user_agent) <= max_length:
|
||||
return user_agent
|
||||
client_identifiers = ("osu!", "osu!lazer", "osu-framework")
|
||||
if any(identifier in lower_ua for identifier in client_identifiers):
|
||||
info.browser = "osu!"
|
||||
info.is_client = True
|
||||
return info
|
||||
|
||||
# 提取操作系统信息
|
||||
os_info = ""
|
||||
os_patterns = [
|
||||
r"(Windows[^;)]*)",
|
||||
r"(Mac OS[^;)]*)",
|
||||
r"(Linux[^;)]*)",
|
||||
r"(Android[^;)]*)",
|
||||
r"(iOS[^;)]*)",
|
||||
r"(iPhone[^;)]*)",
|
||||
r"(iPad[^;)]*)",
|
||||
]
|
||||
browser_patterns: tuple[tuple[re.Pattern[str], str], ...] = (
|
||||
(re.compile(r"OPR/(\d+(?:\.\d+)*)"), "Opera"),
|
||||
(re.compile(r"Edg/(\d+(?:\.\d+)*)"), "Edge"),
|
||||
(re.compile(r"Chrome/(\d+(?:\.\d+)*)"), "Chrome"),
|
||||
(re.compile(r"Firefox/(\d+(?:\.\d+)*)"), "Firefox"),
|
||||
(re.compile(r"Version/(\d+(?:\.\d+)*).*Safari"), "Safari"),
|
||||
(re.compile(r"Safari/(\d+(?:\.\d+)*)"), "Safari"),
|
||||
(re.compile(r"MSIE (\d+(?:\.\d+)*)"), "Internet Explorer"),
|
||||
(re.compile(r"Trident/.*rv:(\d+(?:\.\d+)*)"), "Internet Explorer"),
|
||||
)
|
||||
|
||||
for pattern in os_patterns:
|
||||
match = re.search(pattern, user_agent, re.IGNORECASE)
|
||||
for pattern, name in browser_patterns:
|
||||
match = pattern.search(ua)
|
||||
if match:
|
||||
os_info = match.group(1).strip()
|
||||
info.browser = name
|
||||
info.version = match.group(1)
|
||||
break
|
||||
|
||||
# 提取浏览器信息
|
||||
browser_info = ""
|
||||
browser_patterns = [
|
||||
r"(osu![^)]*)", # osu! 客户端
|
||||
r"(Chrome/[\d.]+)",
|
||||
r"(Firefox/[\d.]+)",
|
||||
r"(Safari/[\d.]+)",
|
||||
r"(Edge/[\d.]+)",
|
||||
r"(Opera/[\d.]+)",
|
||||
]
|
||||
os_patterns: tuple[tuple[re.Pattern[str], str], ...] = (
|
||||
(re.compile(r"windows nt 10"), "Windows 10"),
|
||||
(re.compile(r"windows nt 6\.3"), "Windows 8.1"),
|
||||
(re.compile(r"windows nt 6\.2"), "Windows 8"),
|
||||
(re.compile(r"windows nt 6\.1"), "Windows 7"),
|
||||
(re.compile(r"windows nt 6\.0"), "Windows Vista"),
|
||||
(re.compile(r"windows nt 5\.1"), "Windows XP"),
|
||||
(re.compile(r"mac os x"), "macOS"),
|
||||
(re.compile(r"iphone os"), "iOS"),
|
||||
(re.compile(r"ipad;"), "iPadOS"),
|
||||
(re.compile(r"android"), "Android"),
|
||||
(re.compile(r"linux"), "Linux"),
|
||||
)
|
||||
|
||||
for pattern in browser_patterns:
|
||||
match = re.search(pattern, user_agent, re.IGNORECASE)
|
||||
if match:
|
||||
browser_info = match.group(1).strip()
|
||||
# 如果找到了 osu! 客户端,优先使用
|
||||
if "osu!" in browser_info.lower():
|
||||
break
|
||||
for pattern, name in os_patterns:
|
||||
if pattern.search(lower_ua):
|
||||
info.os = name
|
||||
break
|
||||
|
||||
# 构建简化的 User-Agent
|
||||
parts = []
|
||||
if os_info:
|
||||
parts.append(os_info)
|
||||
if browser_info:
|
||||
parts.append(browser_info)
|
||||
info.is_mobile = any(keyword in lower_ua for keyword in ("mobile", "iphone", "android", "ipod"))
|
||||
info.is_tablet = any(keyword in lower_ua for keyword in ("ipad", "tablet"))
|
||||
# Only classify as PC if not mobile or tablet
|
||||
if (
|
||||
not info.is_mobile
|
||||
and not info.is_tablet
|
||||
and any(keyword in lower_ua for keyword in ("windows", "macintosh", "linux", "x11"))
|
||||
):
|
||||
info.is_pc = True
|
||||
|
||||
if parts:
|
||||
simplified = "; ".join(parts)
|
||||
else:
|
||||
# 如果没有识别到关键信息,截断原始字符串
|
||||
simplified = user_agent[: max_length - 3] + "..."
|
||||
if info.is_tablet:
|
||||
info.platform = "tablet"
|
||||
elif info.is_mobile:
|
||||
info.platform = "mobile"
|
||||
elif info.is_pc:
|
||||
info.platform = "pc"
|
||||
|
||||
# 确保不超过最大长度
|
||||
if len(simplified) > max_length:
|
||||
simplified = simplified[: max_length - 3] + "..."
|
||||
|
||||
return simplified
|
||||
return info
|
||||
|
||||
|
||||
# https://github.com/encode/starlette/blob/master/starlette/_utils.py
|
||||
|
||||
6
main.py
6
main.py
@@ -25,10 +25,6 @@ from app.router import (
|
||||
from app.router.redirect import redirect_router
|
||||
from app.router.v1 import api_v1_public_router
|
||||
from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
|
||||
from app.scheduler.database_cleanup_scheduler import (
|
||||
start_database_cleanup_scheduler,
|
||||
stop_database_cleanup_scheduler,
|
||||
)
|
||||
from app.service.beatmap_download_service import download_service
|
||||
from app.service.beatmapset_update_service import init_beatmapset_update_service
|
||||
from app.service.calculate_all_user_rank import calculate_user_rank
|
||||
@@ -68,7 +64,6 @@ async def lifespan(app: FastAPI):
|
||||
await start_email_processor() # 启动邮件队列处理器
|
||||
await download_service.start_health_check() # 启动下载服务健康检查
|
||||
await start_cache_scheduler() # 启动缓存调度器
|
||||
await start_database_cleanup_scheduler() # 启动数据库清理调度器
|
||||
init_beatmapset_update_service(fetcher) # 初始化谱面集更新服务
|
||||
redis_message_system.start() # 启动 Redis 消息系统
|
||||
load_achievements()
|
||||
@@ -83,7 +78,6 @@ async def lifespan(app: FastAPI):
|
||||
stop_scheduler()
|
||||
redis_message_system.stop() # 停止 Redis 消息系统
|
||||
await stop_cache_scheduler() # 停止缓存调度器
|
||||
await stop_database_cleanup_scheduler() # 停止数据库清理调度器
|
||||
await download_service.stop_health_check() # 停止下载服务健康检查
|
||||
await stop_email_processor() # 停止邮件队列处理器
|
||||
await engine.dispose()
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
"""session: support multi-session
|
||||
|
||||
Revision ID: 72a9b8f3f863
|
||||
Revises: b1ac2154bd0d
|
||||
Create Date: 2025-10-02 07:17:19.297498
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "72a9b8f3f863"
|
||||
down_revision: str | Sequence[str] | None = "b1ac2154bd0d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"trusted_devices",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("ip_address", sa.VARCHAR(length=45), nullable=False),
|
||||
sa.Column("user_agent", sa.Text(), nullable=False),
|
||||
sa.Column("client_type", sa.VARCHAR(length=10), nullable=False),
|
||||
sa.Column("web_uuid", sa.VARCHAR(length=36), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["lazer_users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.alter_column(
|
||||
"login_sessions",
|
||||
"is_new_location",
|
||||
new_column_name="is_new_device",
|
||||
existing_type=mysql.TINYINT(display_width=1),
|
||||
)
|
||||
op.create_index(op.f("ix_trusted_devices_user_id"), "trusted_devices", ["user_id"], unique=False)
|
||||
op.add_column("login_sessions", sa.Column("web_uuid", sa.VARCHAR(length=36), nullable=True))
|
||||
op.alter_column(
|
||||
"login_sessions",
|
||||
"ip_address",
|
||||
existing_type=mysql.VARCHAR(length=255),
|
||||
type_=sa.VARCHAR(length=45),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"login_sessions", "user_agent", existing_type=mysql.VARCHAR(length=250), type_=sa.Text(), existing_nullable=True
|
||||
)
|
||||
op.drop_index(op.f("ix_login_sessions_session_token"), table_name="login_sessions")
|
||||
op.create_foreign_key(None, "login_sessions", "lazer_users", ["user_id"], ["id"])
|
||||
op.drop_column("login_sessions", "country_code")
|
||||
op.drop_column("login_sessions", "session_token")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("login_sessions", sa.Column("session_token", sa.VARCHAR(length=64), nullable=True))
|
||||
op.add_column("login_sessions", sa.Column("country_code", sa.VARCHAR(length=255), nullable=True))
|
||||
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False)
|
||||
|
||||
op.alter_column(
|
||||
"login_sessions",
|
||||
"user_agent",
|
||||
existing_type=sa.Text(),
|
||||
type_=mysql.VARCHAR(length=250),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"login_sessions",
|
||||
"ip_address",
|
||||
existing_type=sa.String(length=45),
|
||||
type_=mysql.VARCHAR(length=255),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
op.drop_column("login_sessions", "web_uuid")
|
||||
op.alter_column(
|
||||
"login_sessions",
|
||||
"is_new_device",
|
||||
new_column_name="is_new_location",
|
||||
existing_type=mysql.TINYINT(display_width=1),
|
||||
)
|
||||
op.drop_constraint(op.f("fk_login_sessions_user_id_lazer_users"), "login_sessions", type_="foreignkey")
|
||||
|
||||
op.drop_index(op.f("ix_trusted_devices_user_id"), table_name="trusted_devices")
|
||||
op.drop_table("trusted_devices")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,40 @@
|
||||
"""auth: add refresh_token_expires_at
|
||||
|
||||
Revision ID: 7fe1319250c5
|
||||
Revises: 72a9b8f3f863
|
||||
Create Date: 2025-10-02 10:50:21.169065
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7fe1319250c5"
|
||||
down_revision: str | Sequence[str] | None = "72a9b8f3f863"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("oauth_tokens", sa.Column("refresh_token_expires_at", sa.DateTime(), nullable=True))
|
||||
op.create_index(op.f("ix_oauth_tokens_expires_at"), "oauth_tokens", ["expires_at"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_tokens_refresh_token_expires_at"), "oauth_tokens", ["refresh_token_expires_at"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_oauth_tokens_refresh_token_expires_at"), table_name="oauth_tokens")
|
||||
op.drop_index(op.f("ix_oauth_tokens_expires_at"), table_name="oauth_tokens")
|
||||
op.drop_column("oauth_tokens", "refresh_token_expires_at")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,35 @@
|
||||
"""session: add device_id to LoginSession
|
||||
|
||||
Revision ID: 9556cd2ec11f
|
||||
Revises: 7fe1319250c5
|
||||
Create Date: 2025-10-02 11:03:09.803140
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9556cd2ec11f"
|
||||
down_revision: str | Sequence[str] | None = "7fe1319250c5"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("login_sessions", sa.Column("device_id", sa.BigInteger(), nullable=True))
|
||||
op.create_index(op.f("ix_login_sessions_device_id"), "login_sessions", ["device_id"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("login_sessions", "device_id")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user