chore(linter): make linter happy

This commit is contained in:
MingxuanGame
2025-09-30 07:57:08 +00:00
parent 0f637446df
commit 017b058e63
15 changed files with 99 additions and 120 deletions

View File

@@ -22,7 +22,7 @@ from jose import JWTError, jwt
from passlib.context import CryptContext
import pyotp
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
# 密码哈希上下文
@@ -242,7 +242,7 @@ async def store_token(
statement = (
select(OAuthToken)
.where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at > utcnow())
.order_by(OAuthToken.created_at.desc())
.order_by(col(OAuthToken.created_at).desc())
)
active_tokens = (await db.exec(statement)).all()
@@ -369,9 +369,7 @@ def verify_totp_key(secret: str, code: str) -> bool:
return pyotp.TOTP(secret).verify(code, valid_window=1)
async def verify_totp_key_with_replay_protection(
user_id: int, secret: str, code: str, redis: Redis
) -> bool:
async def verify_totp_key_with_replay_protection(user_id: int, secret: str, code: str, redis: Redis) -> bool:
"""验证TOTP密钥并防止密钥重放攻击"""
if not pyotp.TOTP(secret).verify(code, valid_window=1):
return False

View File

@@ -12,7 +12,6 @@ from .api_version import APIVersion
from .database import Database, get_redis
from fastapi import Depends, HTTPException
from redis.asyncio import Redis
from fastapi.security import (
APIKeyQuery,
HTTPBearer,
@@ -20,6 +19,7 @@ from fastapi.security import (
OAuth2PasswordBearer,
SecurityScopes,
)
from redis.asyncio import Redis
from sqlmodel import select
security = HTTPBearer()
@@ -103,7 +103,7 @@ async def get_client_user(
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
api_version: APIVersion,
user_and_token: UserAndToken = Depends(get_client_user_and_token)
user_and_token: UserAndToken = Depends(get_client_user_and_token),
):
from app.service.verification_service import LoginSessionService
@@ -128,10 +128,7 @@ async def get_client_user(
await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis)
# 返回符合 osu! API 标准的错误响应
error_response = {
"error": "User not verified",
"method": verify_method
}
error_response = {"error": "User not verified", "method": verify_method}
raise HTTPException(status_code=401, detail=error_response)
return user

View File

@@ -8,7 +8,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
class SessionVerificationInterface(ABC):
@@ -19,7 +18,7 @@ class SessionVerificationInterface(ABC):
@classmethod
@abstractmethod
async def find_for_verification(cls, session_id: str) -> Optional[SessionVerificationInterface]:
async def find_for_verification(cls, session_id: str) -> SessionVerificationInterface | None:
"""根据会话ID查找会话用于验证
Args:
@@ -41,7 +40,7 @@ class SessionVerificationInterface(ABC):
pass
@abstractmethod
def get_verification_method(self) -> Optional[str]:
def get_verification_method(self) -> str | None:
"""获取当前验证方法
Returns:
@@ -69,6 +68,6 @@ class SessionVerificationInterface(ABC):
pass
@abstractmethod
def user_id(self) -> Optional[int]:
def user_id(self) -> int | None:
"""获取关联的用户ID"""
pass

View File

@@ -1,9 +1,5 @@
"""
中间件模块
from __future__ import annotations
提供会话验证和其他中间件功能
"""
from .verify_session import SessionState, VerifySessionMiddleware
from .verify_session import VerifySessionMiddleware, SessionState
__all__ = ["VerifySessionMiddleware", "SessionState"]
__all__ = ["SessionState", "VerifySessionMiddleware"]

View File

@@ -6,20 +6,21 @@
from __future__ import annotations
import asyncio
from typing import Awaitable, Callable, Optional
from collections.abc import Awaitable, Callable
from typing import ClassVar, Literal, cast
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from app.utils import bg_tasks
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel.ext.asyncio.session import AsyncSession
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import with_db, get_redis
from app.log import logger
from app.service.verification_service import LoginSessionService
class SessionVerificationState:
"""会话验证状态管理类
@@ -39,7 +40,7 @@ class SessionVerificationState:
db: AsyncSession,
redis: Redis,
user: User,
) -> Optional[SessionVerificationState]:
) -> SessionVerificationState | None:
"""获取当前会话验证状态"""
try:
# 从请求头或token中获取会话信息
@@ -58,7 +59,7 @@ class SessionVerificationState:
return None
@staticmethod
def _extract_session_token(request: Request) -> Optional[str]:
def _extract_session_token(request: Request) -> str | None:
"""从请求中提取会话token"""
# 尝试从Authorization header提取
auth_header = request.headers.get("Authorization", "")
@@ -79,11 +80,11 @@ class SessionVerificationState:
# 智能选择验证方法
# 参考osu-web: API版本 < 20250913 或用户没有TOTP时使用邮件验证
# 这里简化为检查用户是否有TOTP
totp_key = getattr(self.user, 'totp_key', None)
current_method = 'totp' if totp_key else 'mail'
totp_key = getattr(self.user, "totp_key", None)
current_method = "totp" if totp_key else "mail"
# 设置验证方法
asyncio.create_task(self._set_verification_method(current_method))
bg_tasks.add_task(self._set_verification_method, current_method)
return current_method
@@ -91,11 +92,14 @@ class SessionVerificationState:
"""内部方法:设置验证方法"""
try:
token_id = self.session.token_id
if token_id is not None and method in ['totp', 'mail']:
if token_id is not None and method in ["totp", "mail"]:
# 类型检查确保method是正确的字面量类型
verification_method = method if method in ['totp', 'mail'] else 'totp'
verification_method = method if method in ["totp", "mail"] else "totp"
await LoginSessionService.set_login_method(
self.user.id, token_id, verification_method, self.redis # type: ignore
self.user.id,
token_id,
cast(Literal["totp", "mail"], verification_method),
self.redis,
)
except Exception as e:
logger.error(f"[Session Verification] Error setting verification method: {e}")
@@ -112,9 +116,7 @@ class SessionVerificationState:
try:
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(
db, self.redis, self.user.id, token_id
)
await LoginSessionService.mark_session_verified(db, self.redis, self.user.id, token_id)
finally:
await db.close()
except Exception as e:
@@ -142,8 +144,7 @@ class SessionVerificationState:
db = with_db()
try:
await EmailVerificationService.send_verification_email(
db, self.redis, self.user.id, self.user.username,
self.user.email, None, None
db, self.redis, self.user.id, self.user.username, self.user.email, None, None
)
finally:
await db.close()
@@ -158,7 +159,7 @@ class SessionVerificationController:
"""
# 需要跳过验证的路由参考osu-web的SKIP_VERIFICATION_ROUTES
SKIP_VERIFICATION_ROUTES = {
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/me",
@@ -190,26 +191,18 @@ class SessionVerificationController:
# API请求返回JSON响应
if request.url.path.startswith("/api/"):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"method": method}
)
return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content={"method": method})
# 其他情况可以扩展支持HTML响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"authentication": "verify",
"method": method,
"message": "Session verification required"
}
content={"authentication": "verify", "method": method, "message": "Session verification required"},
)
except Exception as e:
logger.error(f"[Session Verification] Error initiating verification: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Verification initiation failed"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Verification initiation failed"
)
@@ -262,7 +255,7 @@ class SessionVerificationMiddleware:
# 出错时允许请求继续,避免阻塞正常流程
return await call_next(request)
async def _get_user(self, request: Request) -> Optional[User]:
async def _get_user(self, request: Request) -> User | None:
"""获取当前用户"""
try:
# 这里需要手动获取用户,因为在中间件中无法直接使用依赖注入

View File

@@ -1,14 +1,10 @@
"""
中间件设置和配置
展示如何将会话验证中间件集成到FastAPI应用中
"""
from fastapi import FastAPI
from __future__ import annotations
from app.config import settings
from app.middleware.verify_session import VerifySessionMiddleware
from fastapi import FastAPI
def setup_session_verification_middleware(app: FastAPI) -> None:
"""设置会话验证中间件
@@ -22,9 +18,11 @@ def setup_session_verification_middleware(app: FastAPI) -> None:
# 可以在这里添加中间件配置日志
from app.log import logger
logger.info("[Middleware] Session verification middleware enabled")
else:
from app.log import logger
logger.info("[Middleware] Session verification middleware disabled")
@@ -41,4 +39,5 @@ def setup_all_middlewares(app: FastAPI) -> None:
# app.add_middleware(OtherMiddleware)
from app.log import logger
logger.info("[Middleware] All middlewares configured")

View File

@@ -6,22 +6,23 @@ FastAPI会话验证中间件
from __future__ import annotations
from typing import Callable, Optional
from collections.abc import Callable
from typing import ClassVar
from app.auth import get_token_by_access_token
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.service.verification_service import LoginSessionService
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware
from app.database.lazer_user import User
from app.database.verification import LoginSession
from app.dependencies.database import with_db, get_redis
from app.auth import get_token_by_access_token
from app.log import logger
from app.service.verification_service import LoginSessionService
from sqlmodel import select
class VerifySessionMiddleware(BaseHTTPMiddleware):
"""会话验证中间件
@@ -30,7 +31,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
"""
# 需要跳过验证的路由
SKIP_VERIFICATION_ROUTES = {
SKIP_VERIFICATION_ROUTES: ClassVar[set[str]] = {
"/api/v2/session/verify",
"/api/v2/session/verify/reissue",
"/api/v2/me",
@@ -44,7 +45,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
}
# 需要强制验证的路由模式(敏感操作)
ALWAYS_VERIFY_PATTERNS = {
ALWAYS_VERIFY_PATTERNS: ClassVar[set[str]] = {
"/api/v2/account/",
"/api/v2/settings/",
"/api/private/admin/",
@@ -110,9 +111,9 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
return True
# 特权用户或非活跃用户需要验证
if hasattr(user, 'is_privileged') and user.is_privileged():
return True
if hasattr(user, 'is_inactive') and user.is_inactive():
# if hasattr(user, 'is_privileged') and user.is_privileged():
# return True
if not user.is_active:
return True
# 安全方法GET/HEAD/OPTIONS一般不需要验证
@@ -123,7 +124,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
# 修改操作POST/PUT/DELETE/PATCH需要验证
return method in {"POST", "PUT", "DELETE", "PATCH"}
async def _get_current_user(self, request: Request) -> Optional[User]:
async def _get_current_user(self, request: Request) -> User | None:
"""获取当前用户"""
try:
# 从Authorization header提取token
@@ -151,7 +152,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
logger.debug(f"[Verify Session Middleware] Error getting user: {e}")
return None
async def _get_session_state(self, request: Request, user: User) -> Optional[SessionState]:
async def _get_session_state(self, request: Request, user: User) -> SessionState | None:
"""获取会话状态"""
try:
# 提取会话token这里简化为使用相同的auth token
@@ -191,17 +192,13 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
# 返回验证要求响应
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"method": method,
"message": "Session verification required"
}
content={"method": method, "message": "Session verification required"},
)
except Exception as e:
logger.error(f"[Verify Session Middleware] Error initiating verification: {e}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "Verification initiation failed"}
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": "Verification initiation failed"}
)
@@ -216,7 +213,7 @@ class SessionState:
self.user = user
self.redis = redis
self.db = db
self._verification_method: Optional[str] = None
self._verification_method: str | None = None
def is_verified(self) -> bool:
"""检查会话是否已验证"""
@@ -236,8 +233,8 @@ class SessionState:
if self._verification_method is None:
# 检查用户是否有TOTP密钥
await self.user.awaitable_attrs.totp_key # 预加载
totp_key = getattr(self.user, 'totp_key', None)
self._verification_method = 'totp' if totp_key else 'mail'
totp_key = getattr(self.user, "totp_key", None)
self._verification_method = "totp" if totp_key else "mail"
# 保存选择的方法
token_id = self.session.token_id
@@ -253,9 +250,7 @@ class SessionState:
try:
token_id = self.session.token_id
if token_id is not None:
await LoginSessionService.mark_session_verified(
self.db, self.redis, self.user.id, token_id
)
await LoginSessionService.mark_session_verified(self.db, self.redis, self.user.id, token_id)
self.session.is_verified = True # 更新本地状态
except Exception as e:
logger.error(f"[Session State] Error marking verified: {e}")
@@ -268,8 +263,7 @@ class SessionState:
# 这里可以触发邮件发送
await EmailVerificationService.send_verification_email(
self.db, self.redis, self.user.id, self.user.username,
self.user.email, None, None
self.db, self.redis, self.user.id, self.user.username, self.user.email, None, None
)
except Exception as e:
logger.error(f"[Session State] Error issuing mail: {e}")

View File

@@ -17,13 +17,14 @@ from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from .router import router
from fastapi import Body, Depends, HTTPException, Security
import pyotp
from pydantic import BaseModel
import pyotp
from redis.asyncio import Redis
class TotpStatusResp(BaseModel):
"""TOTP状态响应"""
enabled: bool
created_at: str | None = None
@@ -42,10 +43,7 @@ async def get_totp_status(
totp_key = await current_user.awaitable_attrs.totp_key
if totp_key:
return TotpStatusResp(
enabled=True,
created_at=totp_key.created_at.isoformat()
)
return TotpStatusResp(enabled=True, created_at=totp_key.created_at.isoformat())
else:
return TotpStatusResp(enabled=False)

View File

@@ -354,7 +354,7 @@ async def get_user_ranking(
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
# 查询总数
count_query = select(func.count(UserStatistics.id)).where(*wheres)
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
total_count_result = await session.exec(count_query)
total_count = total_count_result.one()

View File

@@ -86,7 +86,7 @@ async def verify_session(
if verify_method is None:
# 智能选择验证方法参考osu-web实现
# API版本较老或用户未设置TOTP时强制使用邮件验证
#print(api_version, totp_key)
# print(api_version, totp_key)
if api_version < 20240101 or totp_key is None:
verify_method = "mail"
else:
@@ -153,11 +153,11 @@ async def verify_session(
}
# 如果有具体的错误原因,添加到响应中
if hasattr(e, 'reason') and e.reason:
if hasattr(e, "reason") and e.reason:
error_response["reason"] = e.reason
# 如果需要重新发送邮件验证码
if hasattr(e, 'should_reissue') and e.should_reissue and verify_method == "mail":
if hasattr(e, "should_reissue") and e.should_reissue and verify_method == "mail":
try:
await EmailVerificationService.send_verification_email(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent

View File

@@ -80,7 +80,10 @@ class DatabaseCleanupScheduler:
# 只在有清理记录时输出总结
total_cleaned = expired_codes + expired_sessions + unverified_sessions
if total_cleaned > 0:
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}")
logger.debug(
f"Scheduled cleanup completed - codes: {expired_codes}, "
f"sessions: {expired_sessions}, unverified: {unverified_sessions}"
)
except Exception as e:
logger.error(f"Error during scheduled database cleanup: {e!s}")

View File

@@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession
from app.log import logger
from app.utils import utcnow
from sqlmodel import select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -107,7 +107,7 @@ class DatabaseCleanupService:
# 查找指定天数前的已使用验证码记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(EmailVerification).where(EmailVerification.is_used == True)
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
result = await db.exec(stmt)
all_used_codes = result.all()
@@ -152,8 +152,7 @@ class DatabaseCleanupService:
# 查找指定时间前创建且仍未验证的会话记录
stmt = select(LoginSession).where(
LoginSession.is_verified == False,
LoginSession.created_at < cutoff_time
col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_time
)
result = await db.exec(stmt)
unverified_sessions = result.all()
@@ -168,7 +167,8 @@ class DatabaseCleanupService:
if deleted_count > 0:
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)"
f"[Cleanup Service] Cleaned up {deleted_count} unverified "
f"login sessions older than {hours_old} hour(s)"
)
return deleted_count
@@ -194,7 +194,7 @@ class DatabaseCleanupService:
# 查找指定天数前的已验证会话记录
cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(LoginSession).where(LoginSession.is_verified == True)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
result = await db.exec(stmt)
all_verified_sessions = result.all()
@@ -290,14 +290,13 @@ class DatabaseCleanupService:
# 统计1小时前未验证的登录会话数量
unverified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == False,
LoginSession.created_at < cutoff_1_hour
col(LoginSession.is_verified).is_(False), LoginSession.created_at < cutoff_1_hour
)
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
unverified_sessions_count = len(unverified_sessions_result.all())
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True)
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all()
old_used_codes_count = len(
@@ -305,7 +304,7 @@ class DatabaseCleanupService:
)
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True)
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len(

View File

@@ -7,11 +7,10 @@ from __future__ import annotations
from datetime import timedelta
import secrets
import string
from typing import Literal, Optional
from typing import Literal
from app.config import settings
from app.database.verification import EmailVerification, LoginSession
from app.interfaces.session_verification import SessionVerificationInterface
from app.log import logger
from app.service.client_detection_service import ClientDetectionService, ClientInfo
from app.service.device_trust_service import DeviceTrustService
@@ -517,7 +516,7 @@ class LoginSessionService:
# Session verification interface methods
@staticmethod
async def find_for_verification(db: AsyncSession, session_id: str) -> Optional[LoginSession]:
async def find_for_verification(db: AsyncSession, session_id: str) -> LoginSession | None:
"""根据会话ID查找会话用于验证"""
try:
result = await db.exec(

View File

@@ -9,6 +9,7 @@ from app.dependencies.database import Database, engine, get_redis, redis_client
from app.dependencies.fetcher import get_fetcher
from app.dependencies.scheduler import start_scheduler, stop_scheduler
from app.log import logger
from app.middleware.verify_session import VerifySessionMiddleware
from app.router import (
api_v1_router,
api_v2_router,
@@ -46,8 +47,6 @@ from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_limiter import FastAPILimiter
import sentry_sdk
from app.middleware.verify_session import VerifySessionMiddleware
@asynccontextmanager
async def lifespan(app: FastAPI):

View File

@@ -5,6 +5,7 @@ Revises: fe8e9f3da298
Create Date: 2025-09-24 00:46:57.367742
"""
from __future__ import annotations
from collections.abc import Sequence
@@ -23,8 +24,12 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("login_sessions", sa.Column("session_token", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=True))
op.add_column("login_sessions", sa.Column("verification_method", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True))
op.add_column(
"login_sessions", sa.Column("session_token", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=True)
)
op.add_column(
"login_sessions", sa.Column("verification_method", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True)
)
op.create_index(op.f("ix_login_sessions_session_token"), "login_sessions", ["session_token"], unique=False)
# ### end Alembic commands ###