From 3a434ee02c7cf50a599bc50346138d8adf6bbf45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=92=95=E8=B0=B7=E9=85=B1?= Date: Wed, 24 Sep 2025 03:04:09 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=83=A8=E5=88=86=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/dependencies/user.py | 35 ++++++++++- app/router/notification/server.py | 11 +++- app/router/v2/session_verify.py | 4 +- app/scheduler/database_cleanup_scheduler.py | 7 ++- app/service/database_cleanup_service.py | 69 +++++++++++++++++++-- 5 files changed, 113 insertions(+), 13 deletions(-) diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 674df37..aa091b4 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -8,9 +8,11 @@ from app.database import User from app.database.auth import OAuthToken, V1APIKeys from app.models.oauth import OAuth2ClientCredentialsBearer -from .database import Database +from .api_version import APIVersion +from .database import Database, get_redis from fastapi import Depends, HTTPException +from redis.asyncio import Redis from fastapi.security import ( APIKeyQuery, HTTPBearer, @@ -97,13 +99,40 @@ async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get return user_and_token[0] -async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)): +async def get_client_user( + db: Database, + redis: Annotated[Redis, Depends(get_redis)], + api_version: APIVersion, + user_and_token: UserAndToken = Depends(get_client_user_and_token) +): from app.service.verification_service import LoginSessionService user, token = user_and_token if await LoginSessionService.check_is_need_verification(db, user.id, token.id): - raise HTTPException(status_code=403, detail="User not verified") + # 获取当前验证方式 + verify_method = None + if api_version >= 20250913: + verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis) + + if verify_method is None: + # 智能选择验证方式(有TOTP优先TOTP) + totp_key = await user.awaitable_attrs.totp_key + if totp_key is not None and api_version >= 20240101: + verify_method = "totp" + else: + verify_method = "mail" + + # 设置选择的验证方法到Redis中,避免重复选择 + if api_version >= 20250913: + await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis) + + # 返回符合 osu! API 标准的错误响应 + error_response = { + "error": "User not verified", + "method": verify_method + } + raise HTTPException(status_code=401, detail=error_response) return user diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 16e9836..7fe8859 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -48,9 +48,18 @@ class ChatServer: user_id = user.id if user_id in self.connect_client: del self.connect_client[user_id] + + # 创建频道ID列表的副本以避免在迭代过程中修改字典 + channel_ids_to_process = [] for channel_id, channel in self.channels.items(): if user_id in channel: - channel.remove(user_id) + channel_ids_to_process.append(channel_id) + + # 现在安全地处理每个频道 + for channel_id in channel_ids_to_process: + # 再次检查用户是否仍在频道中(防止并发修改) + if channel_id in self.channels and user_id in self.channels[channel_id]: + self.channels[channel_id].remove(user_id) # 使用明确的查询避免延迟加载 db_channel = ( await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id)) diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index 88b9201..ce60225 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -86,7 +86,7 @@ async def verify_session( if verify_method is None: # 智能选择验证方法(参考osu-web实现) # API版本较老或用户未设置TOTP时强制使用邮件验证 - print(api_version, totp_key) + #print(api_version, totp_key) if api_version < 20240101 or totp_key is None: verify_method = "mail" else: @@ -148,8 +148,8 @@ async def verify_session( # 构建更详细的错误响应(参考osu-web的错误处理) error_response = { - "method": verify_method, "error": str(e), + "method": verify_method, } # 如果有具体的错误原因,添加到响应中 diff --git a/app/scheduler/database_cleanup_scheduler.py b/app/scheduler/database_cleanup_scheduler.py index 1a21bcf..43a94a2 100644 --- a/app/scheduler/database_cleanup_scheduler.py +++ b/app/scheduler/database_cleanup_scheduler.py @@ -74,10 +74,13 @@ class DatabaseCleanupScheduler: # 清理过期的登录会话 expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db) + # 清理1小时前未验证的登录会话 + unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1) + # 只在有清理记录时输出总结 - total_cleaned = expired_codes + expired_sessions + total_cleaned = expired_codes + expired_sessions + unverified_sessions if total_cleaned > 0: - logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}") + logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}") except Exception as e: logger.error(f"Error during scheduled database cleanup: {e!s}") diff --git a/app/service/database_cleanup_service.py b/app/service/database_cleanup_service.py index 75dfce9..6fea2df 100644 --- a/app/service/database_cleanup_service.py +++ b/app/service/database_cleanup_service.py @@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession from app.log import logger from app.utils import utcnow -from sqlmodel import col, select +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -107,7 +107,7 @@ class DatabaseCleanupService: # 查找指定天数前的已使用验证码记录 cutoff_time = utcnow() - timedelta(days=days_old) - stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) + stmt = select(EmailVerification).where(EmailVerification.is_used == True) result = await db.exec(stmt) all_used_codes = result.all() @@ -134,6 +134,50 @@ class DatabaseCleanupService: logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}") return 0 + @staticmethod + async def cleanup_unverified_login_sessions(db: AsyncSession, hours_old: int = 1) -> int: + """ + 清理指定小时前创建但仍未验证的登录会话 + + Args: + db: 数据库会话 + hours_old: 清理多少小时前创建但仍未验证的会话,默认1小时 + + Returns: + int: 清理的记录数 + """ + try: + # 计算截止时间 + cutoff_time = utcnow() - timedelta(hours=hours_old) + + # 查找指定时间前创建且仍未验证的会话记录 + stmt = select(LoginSession).where( + LoginSession.is_verified == False, + LoginSession.created_at < cutoff_time + ) + result = await db.exec(stmt) + unverified_sessions = result.all() + + # 删除未验证的会话记录 + deleted_count = 0 + for session in unverified_sessions: + await db.delete(session) + deleted_count += 1 + + await db.commit() + + if deleted_count > 0: + logger.debug( + f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)" + ) + + return deleted_count + + except Exception as e: + await db.rollback() + logger.error(f"[Cleanup Service] Error cleaning unverified login sessions: {e!s}") + return 0 + @staticmethod async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int: """ @@ -150,7 +194,7 @@ class DatabaseCleanupService: # 查找指定天数前的已验证会话记录 cutoff_time = utcnow() - timedelta(days=days_old) - stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) + stmt = select(LoginSession).where(LoginSession.is_verified == True) result = await db.exec(stmt) all_verified_sessions = result.all() @@ -200,6 +244,9 @@ class DatabaseCleanupService: # 清理过期的登录会话 results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db) + # 清理1小时前未验证的登录会话 + results["unverified_login_sessions"] = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1) + # 清理7天前的已使用验证码 results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7) @@ -227,6 +274,7 @@ class DatabaseCleanupService: """ try: current_time = utcnow() + cutoff_1_hour = current_time - timedelta(hours=1) cutoff_7_days = current_time - timedelta(days=7) cutoff_30_days = current_time - timedelta(days=30) @@ -240,8 +288,16 @@ class DatabaseCleanupService: expired_sessions_result = await db.exec(expired_sessions_stmt) expired_sessions_count = len(expired_sessions_result.all()) + # 统计1小时前未验证的登录会话数量 + unverified_sessions_stmt = select(LoginSession).where( + LoginSession.is_verified == False, + LoginSession.created_at < cutoff_1_hour + ) + unverified_sessions_result = await db.exec(unverified_sessions_stmt) + unverified_sessions_count = len(unverified_sessions_result.all()) + # 统计7天前的已使用验证码数量 - old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) + old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True) old_used_codes_result = await db.exec(old_used_codes_stmt) all_used_codes = old_used_codes_result.all() old_used_codes_count = len( @@ -249,7 +305,7 @@ class DatabaseCleanupService: ) # 统计30天前的已验证会话数量 - old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) + old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True) old_verified_sessions_result = await db.exec(old_verified_sessions_stmt) all_verified_sessions = old_verified_sessions_result.all() old_verified_sessions_count = len( @@ -263,10 +319,12 @@ class DatabaseCleanupService: return { "expired_verification_codes": expired_codes_count, "expired_login_sessions": expired_sessions_count, + "unverified_login_sessions": unverified_sessions_count, "old_used_verification_codes": old_used_codes_count, "old_verified_sessions": old_verified_sessions_count, "total_cleanable": expired_codes_count + expired_sessions_count + + unverified_sessions_count + old_used_codes_count + old_verified_sessions_count, } @@ -276,6 +334,7 @@ class DatabaseCleanupService: return { "expired_verification_codes": 0, "expired_login_sessions": 0, + "unverified_login_sessions": 0, "old_used_verification_codes": 0, "old_verified_sessions": 0, "total_cleanable": 0,