修复部分报错

This commit is contained in:
咕谷酱
2025-09-24 03:04:09 +08:00
parent 86c7bbb74e
commit 3a434ee02c
5 changed files with 113 additions and 13 deletions

View File

@@ -8,9 +8,11 @@ from app.database import User
from app.database.auth import OAuthToken, V1APIKeys from app.database.auth import OAuthToken, V1APIKeys
from app.models.oauth import OAuth2ClientCredentialsBearer from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import Database from .api_version import APIVersion
from .database import Database, get_redis
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from redis.asyncio import Redis
from fastapi.security import ( from fastapi.security import (
APIKeyQuery, APIKeyQuery,
HTTPBearer, HTTPBearer,
@@ -97,13 +99,40 @@ async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get
return user_and_token[0] return user_and_token[0]
async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)): async def get_client_user(
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
api_version: APIVersion,
user_and_token: UserAndToken = Depends(get_client_user_and_token)
):
from app.service.verification_service import LoginSessionService from app.service.verification_service import LoginSessionService
user, token = user_and_token user, token = user_and_token
if await LoginSessionService.check_is_need_verification(db, user.id, token.id): if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
raise HTTPException(status_code=403, detail="User not verified") # 获取当前验证方式
verify_method = None
if api_version >= 20250913:
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
if verify_method is None:
# 智能选择验证方式有TOTP优先TOTP
totp_key = await user.awaitable_attrs.totp_key
if totp_key is not None and api_version >= 20240101:
verify_method = "totp"
else:
verify_method = "mail"
# 设置选择的验证方法到Redis中避免重复选择
if api_version >= 20250913:
await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis)
# 返回符合 osu! API 标准的错误响应
error_response = {
"error": "User not verified",
"method": verify_method
}
raise HTTPException(status_code=401, detail=error_response)
return user return user

View File

@@ -48,9 +48,18 @@ class ChatServer:
user_id = user.id user_id = user.id
if user_id in self.connect_client: if user_id in self.connect_client:
del self.connect_client[user_id] del self.connect_client[user_id]
# 创建频道ID列表的副本以避免在迭代过程中修改字典
channel_ids_to_process = []
for channel_id, channel in self.channels.items(): for channel_id, channel in self.channels.items():
if user_id in channel: if user_id in channel:
channel.remove(user_id) channel_ids_to_process.append(channel_id)
# 现在安全地处理每个频道
for channel_id in channel_ids_to_process:
# 再次检查用户是否仍在频道中(防止并发修改)
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
db_channel = ( db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id)) await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))

View File

@@ -86,7 +86,7 @@ async def verify_session(
if verify_method is None: if verify_method is None:
# 智能选择验证方法参考osu-web实现 # 智能选择验证方法参考osu-web实现
# API版本较老或用户未设置TOTP时强制使用邮件验证 # API版本较老或用户未设置TOTP时强制使用邮件验证
print(api_version, totp_key) #print(api_version, totp_key)
if api_version < 20240101 or totp_key is None: if api_version < 20240101 or totp_key is None:
verify_method = "mail" verify_method = "mail"
else: else:
@@ -148,8 +148,8 @@ async def verify_session(
# 构建更详细的错误响应参考osu-web的错误处理 # 构建更详细的错误响应参考osu-web的错误处理
error_response = { error_response = {
"method": verify_method,
"error": str(e), "error": str(e),
"method": verify_method,
} }
# 如果有具体的错误原因,添加到响应中 # 如果有具体的错误原因,添加到响应中

View File

@@ -74,10 +74,13 @@ class DatabaseCleanupScheduler:
# 清理过期的登录会话 # 清理过期的登录会话
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db) expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理1小时前未验证的登录会话
unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
# 只在有清理记录时输出总结 # 只在有清理记录时输出总结
total_cleaned = expired_codes + expired_sessions total_cleaned = expired_codes + expired_sessions + unverified_sessions
if total_cleaned > 0: if total_cleaned > 0:
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}") logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}")
except Exception as e: except Exception as e:
logger.error(f"Error during scheduled database cleanup: {e!s}") logger.error(f"Error during scheduled database cleanup: {e!s}")

View File

@@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession
from app.log import logger from app.log import logger
from app.utils import utcnow from app.utils import utcnow
from sqlmodel import col, select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -107,7 +107,7 @@ class DatabaseCleanupService:
# 查找指定天数前的已使用验证码记录 # 查找指定天数前的已使用验证码记录
cutoff_time = utcnow() - timedelta(days=days_old) cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) stmt = select(EmailVerification).where(EmailVerification.is_used == True)
result = await db.exec(stmt) result = await db.exec(stmt)
all_used_codes = result.all() all_used_codes = result.all()
@@ -134,6 +134,50 @@ class DatabaseCleanupService:
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}") logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
return 0 return 0
@staticmethod
async def cleanup_unverified_login_sessions(db: AsyncSession, hours_old: int = 1) -> int:
"""
清理指定小时前创建但仍未验证的登录会话
Args:
db: 数据库会话
hours_old: 清理多少小时前创建但仍未验证的会话默认1小时
Returns:
int: 清理的记录数
"""
try:
# 计算截止时间
cutoff_time = utcnow() - timedelta(hours=hours_old)
# 查找指定时间前创建且仍未验证的会话记录
stmt = select(LoginSession).where(
LoginSession.is_verified == False,
LoginSession.created_at < cutoff_time
)
result = await db.exec(stmt)
unverified_sessions = result.all()
# 删除未验证的会话记录
deleted_count = 0
for session in unverified_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)"
)
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning unverified login sessions: {e!s}")
return 0
@staticmethod @staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int: async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
""" """
@@ -150,7 +194,7 @@ class DatabaseCleanupService:
# 查找指定天数前的已验证会话记录 # 查找指定天数前的已验证会话记录
cutoff_time = utcnow() - timedelta(days=days_old) cutoff_time = utcnow() - timedelta(days=days_old)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) stmt = select(LoginSession).where(LoginSession.is_verified == True)
result = await db.exec(stmt) result = await db.exec(stmt)
all_verified_sessions = result.all() all_verified_sessions = result.all()
@@ -200,6 +244,9 @@ class DatabaseCleanupService:
# 清理过期的登录会话 # 清理过期的登录会话
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db) results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理1小时前未验证的登录会话
results["unverified_login_sessions"] = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
# 清理7天前的已使用验证码 # 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7) results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
@@ -227,6 +274,7 @@ class DatabaseCleanupService:
""" """
try: try:
current_time = utcnow() current_time = utcnow()
cutoff_1_hour = current_time - timedelta(hours=1)
cutoff_7_days = current_time - timedelta(days=7) cutoff_7_days = current_time - timedelta(days=7)
cutoff_30_days = current_time - timedelta(days=30) cutoff_30_days = current_time - timedelta(days=30)
@@ -240,8 +288,16 @@ class DatabaseCleanupService:
expired_sessions_result = await db.exec(expired_sessions_stmt) expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all()) expired_sessions_count = len(expired_sessions_result.all())
# 统计1小时前未验证的登录会话数量
unverified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == False,
LoginSession.created_at < cutoff_1_hour
)
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
unverified_sessions_count = len(unverified_sessions_result.all())
# 统计7天前的已使用验证码数量 # 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True)
old_used_codes_result = await db.exec(old_used_codes_stmt) old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all() all_used_codes = old_used_codes_result.all()
old_used_codes_count = len( old_used_codes_count = len(
@@ -249,7 +305,7 @@ class DatabaseCleanupService:
) )
# 统计30天前的已验证会话数量 # 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True)
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt) old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all() all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len( old_verified_sessions_count = len(
@@ -263,10 +319,12 @@ class DatabaseCleanupService:
return { return {
"expired_verification_codes": expired_codes_count, "expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count, "expired_login_sessions": expired_sessions_count,
"unverified_login_sessions": unverified_sessions_count,
"old_used_verification_codes": old_used_codes_count, "old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count, "old_verified_sessions": old_verified_sessions_count,
"total_cleanable": expired_codes_count "total_cleanable": expired_codes_count
+ expired_sessions_count + expired_sessions_count
+ unverified_sessions_count
+ old_used_codes_count + old_used_codes_count
+ old_verified_sessions_count, + old_verified_sessions_count,
} }
@@ -276,6 +334,7 @@ class DatabaseCleanupService:
return { return {
"expired_verification_codes": 0, "expired_verification_codes": 0,
"expired_login_sessions": 0, "expired_login_sessions": 0,
"unverified_login_sessions": 0,
"old_used_verification_codes": 0, "old_used_verification_codes": 0,
"old_verified_sessions": 0, "old_verified_sessions": 0,
"total_cleanable": 0, "total_cleanable": 0,