style(project): remove from __future__ import annotations
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .verify_session import SessionState, VerifySessionMiddleware
|
||||
|
||||
__all__ = ["SessionState", "VerifySessionMiddleware"]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.config import settings
|
||||
from app.middleware.verify_session import VerifySessionMiddleware
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ FastAPI会话验证中间件
|
||||
基于osu-web的会话验证系统,适配FastAPI框架
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ClassVar
|
||||
|
||||
@@ -28,6 +26,96 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
logger = log("Middleware")
|
||||
|
||||
|
||||
class SessionState:
|
||||
"""会话状态类
|
||||
|
||||
简化版本的会话状态管理
|
||||
"""
|
||||
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession, api_version: int = 0) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.redis = redis
|
||||
self.db = db
|
||||
self.api_version = api_version
|
||||
self._verification_method: str | None = None
|
||||
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
return self.session.is_verified
|
||||
|
||||
async def get_method(self) -> str:
|
||||
"""获取验证方法"""
|
||||
if self._verification_method is None:
|
||||
# 从Redis获取已设置的方法
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
self._verification_method = await LoginSessionService.get_login_method(
|
||||
self.user.id, token_id, self.redis
|
||||
)
|
||||
|
||||
if self._verification_method is None:
|
||||
if self.api_version < SUPPORT_TOTP_VERIFICATION_VER:
|
||||
self._verification_method = "mail"
|
||||
return self._verification_method
|
||||
|
||||
await self.user.awaitable_attrs.totp_key
|
||||
totp_key = self.user.totp_key
|
||||
self._verification_method = "totp" if totp_key else "mail"
|
||||
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.set_login_method(
|
||||
self.user.id, token_id, self._verification_method, self.redis
|
||||
)
|
||||
|
||||
return self._verification_method
|
||||
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.mark_session_verified(
|
||||
self.db,
|
||||
self.redis,
|
||||
self.user.id,
|
||||
token_id,
|
||||
self.session.ip_address,
|
||||
extract_user_agent(self.session.user_agent),
|
||||
self.session.web_uuid,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error marking verified: {e}")
|
||||
|
||||
async def issue_mail_if_needed(self) -> None:
|
||||
"""如果需要,发送验证邮件"""
|
||||
try:
|
||||
if await self.get_method() == "mail":
|
||||
from app.service.verification_service import EmailVerificationService
|
||||
|
||||
# 这里可以触发邮件发送
|
||||
await EmailVerificationService.send_verification_email(
|
||||
self.db, self.redis, self.user.id, self.user.username, self.user.email, None, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error issuing mail: {e}")
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥"""
|
||||
return str(self.session.id) if self.session.id else ""
|
||||
|
||||
@property
|
||||
def key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return LoginSessionService.get_key_for_event(self.get_key())
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
|
||||
|
||||
class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
"""会话验证中间件
|
||||
|
||||
@@ -192,93 +280,3 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": "Verification initiation failed"}
|
||||
)
|
||||
|
||||
|
||||
class SessionState:
|
||||
"""会话状态类
|
||||
|
||||
简化版本的会话状态管理
|
||||
"""
|
||||
|
||||
def __init__(self, session: LoginSession, user: User, redis: Redis, db: AsyncSession, api_version: int = 0) -> None:
|
||||
self.session = session
|
||||
self.user = user
|
||||
self.redis = redis
|
||||
self.db = db
|
||||
self.api_version = api_version
|
||||
self._verification_method: str | None = None
|
||||
|
||||
def is_verified(self) -> bool:
|
||||
"""检查会话是否已验证"""
|
||||
return self.session.is_verified
|
||||
|
||||
async def get_method(self) -> str:
|
||||
"""获取验证方法"""
|
||||
if self._verification_method is None:
|
||||
# 从Redis获取已设置的方法
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
self._verification_method = await LoginSessionService.get_login_method(
|
||||
self.user.id, token_id, self.redis
|
||||
)
|
||||
|
||||
if self._verification_method is None:
|
||||
if self.api_version < SUPPORT_TOTP_VERIFICATION_VER:
|
||||
self._verification_method = "mail"
|
||||
return self._verification_method
|
||||
|
||||
await self.user.awaitable_attrs.totp_key
|
||||
totp_key = self.user.totp_key
|
||||
self._verification_method = "totp" if totp_key else "mail"
|
||||
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.set_login_method(
|
||||
self.user.id, token_id, self._verification_method, self.redis
|
||||
)
|
||||
|
||||
return self._verification_method
|
||||
|
||||
async def mark_verified(self) -> None:
|
||||
"""标记会话为已验证"""
|
||||
try:
|
||||
token_id = self.session.token_id
|
||||
if token_id is not None:
|
||||
await LoginSessionService.mark_session_verified(
|
||||
self.db,
|
||||
self.redis,
|
||||
self.user.id,
|
||||
token_id,
|
||||
self.session.ip_address,
|
||||
extract_user_agent(self.session.user_agent),
|
||||
self.session.web_uuid,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error marking verified: {e}")
|
||||
|
||||
async def issue_mail_if_needed(self) -> None:
|
||||
"""如果需要,发送验证邮件"""
|
||||
try:
|
||||
if await self.get_method() == "mail":
|
||||
from app.service.verification_service import EmailVerificationService
|
||||
|
||||
# 这里可以触发邮件发送
|
||||
await EmailVerificationService.send_verification_email(
|
||||
self.db, self.redis, self.user.id, self.user.username, self.user.email, None, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error issuing mail: {e}")
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""获取会话密钥"""
|
||||
return str(self.session.id) if self.session.id else ""
|
||||
|
||||
@property
|
||||
def key_for_event(self) -> str:
|
||||
"""获取用于事件广播的会话密钥"""
|
||||
return LoginSessionService.get_key_for_event(self.get_key())
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
"""获取用户ID"""
|
||||
return self.user.id
|
||||
|
||||
Reference in New Issue
Block a user