修复部分报错
This commit is contained in:
@@ -8,9 +8,11 @@ from app.database import User
|
|||||||
from app.database.auth import OAuthToken, V1APIKeys
|
from app.database.auth import OAuthToken, V1APIKeys
|
||||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||||
|
|
||||||
from .database import Database
|
from .api_version import APIVersion
|
||||||
|
from .database import Database, get_redis
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
|
from redis.asyncio import Redis
|
||||||
from fastapi.security import (
|
from fastapi.security import (
|
||||||
APIKeyQuery,
|
APIKeyQuery,
|
||||||
HTTPBearer,
|
HTTPBearer,
|
||||||
@@ -97,13 +99,40 @@ async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get
|
|||||||
return user_and_token[0]
|
return user_and_token[0]
|
||||||
|
|
||||||
|
|
||||||
async def get_client_user(db: Database, user_and_token: UserAndToken = Depends(get_client_user_and_token)):
|
async def get_client_user(
|
||||||
|
db: Database,
|
||||||
|
redis: Annotated[Redis, Depends(get_redis)],
|
||||||
|
api_version: APIVersion,
|
||||||
|
user_and_token: UserAndToken = Depends(get_client_user_and_token)
|
||||||
|
):
|
||||||
from app.service.verification_service import LoginSessionService
|
from app.service.verification_service import LoginSessionService
|
||||||
|
|
||||||
user, token = user_and_token
|
user, token = user_and_token
|
||||||
|
|
||||||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||||||
raise HTTPException(status_code=403, detail="User not verified")
|
# 获取当前验证方式
|
||||||
|
verify_method = None
|
||||||
|
if api_version >= 20250913:
|
||||||
|
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
|
||||||
|
|
||||||
|
if verify_method is None:
|
||||||
|
# 智能选择验证方式(有TOTP优先TOTP)
|
||||||
|
totp_key = await user.awaitable_attrs.totp_key
|
||||||
|
if totp_key is not None and api_version >= 20240101:
|
||||||
|
verify_method = "totp"
|
||||||
|
else:
|
||||||
|
verify_method = "mail"
|
||||||
|
|
||||||
|
# 设置选择的验证方法到Redis中,避免重复选择
|
||||||
|
if api_version >= 20250913:
|
||||||
|
await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis)
|
||||||
|
|
||||||
|
# 返回符合 osu! API 标准的错误响应
|
||||||
|
error_response = {
|
||||||
|
"error": "User not verified",
|
||||||
|
"method": verify_method
|
||||||
|
}
|
||||||
|
raise HTTPException(status_code=401, detail=error_response)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -48,9 +48,18 @@ class ChatServer:
|
|||||||
user_id = user.id
|
user_id = user.id
|
||||||
if user_id in self.connect_client:
|
if user_id in self.connect_client:
|
||||||
del self.connect_client[user_id]
|
del self.connect_client[user_id]
|
||||||
|
|
||||||
|
# 创建频道ID列表的副本以避免在迭代过程中修改字典
|
||||||
|
channel_ids_to_process = []
|
||||||
for channel_id, channel in self.channels.items():
|
for channel_id, channel in self.channels.items():
|
||||||
if user_id in channel:
|
if user_id in channel:
|
||||||
channel.remove(user_id)
|
channel_ids_to_process.append(channel_id)
|
||||||
|
|
||||||
|
# 现在安全地处理每个频道
|
||||||
|
for channel_id in channel_ids_to_process:
|
||||||
|
# 再次检查用户是否仍在频道中(防止并发修改)
|
||||||
|
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||||
|
self.channels[channel_id].remove(user_id)
|
||||||
# 使用明确的查询避免延迟加载
|
# 使用明确的查询避免延迟加载
|
||||||
db_channel = (
|
db_channel = (
|
||||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ async def verify_session(
|
|||||||
if verify_method is None:
|
if verify_method is None:
|
||||||
# 智能选择验证方法(参考osu-web实现)
|
# 智能选择验证方法(参考osu-web实现)
|
||||||
# API版本较老或用户未设置TOTP时强制使用邮件验证
|
# API版本较老或用户未设置TOTP时强制使用邮件验证
|
||||||
print(api_version, totp_key)
|
#print(api_version, totp_key)
|
||||||
if api_version < 20240101 or totp_key is None:
|
if api_version < 20240101 or totp_key is None:
|
||||||
verify_method = "mail"
|
verify_method = "mail"
|
||||||
else:
|
else:
|
||||||
@@ -148,8 +148,8 @@ async def verify_session(
|
|||||||
|
|
||||||
# 构建更详细的错误响应(参考osu-web的错误处理)
|
# 构建更详细的错误响应(参考osu-web的错误处理)
|
||||||
error_response = {
|
error_response = {
|
||||||
"method": verify_method,
|
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
"method": verify_method,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果有具体的错误原因,添加到响应中
|
# 如果有具体的错误原因,添加到响应中
|
||||||
|
|||||||
@@ -74,10 +74,13 @@ class DatabaseCleanupScheduler:
|
|||||||
# 清理过期的登录会话
|
# 清理过期的登录会话
|
||||||
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||||
|
|
||||||
|
# 清理1小时前未验证的登录会话
|
||||||
|
unverified_sessions = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
|
||||||
|
|
||||||
# 只在有清理记录时输出总结
|
# 只在有清理记录时输出总结
|
||||||
total_cleaned = expired_codes + expired_sessions
|
total_cleaned = expired_codes + expired_sessions + unverified_sessions
|
||||||
if total_cleaned > 0:
|
if total_cleaned > 0:
|
||||||
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
|
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}, unverified: {unverified_sessions}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during scheduled database cleanup: {e!s}")
|
logger.error(f"Error during scheduled database cleanup: {e!s}")
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from app.database.verification import EmailVerification, LoginSession
|
|||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.utils import utcnow
|
from app.utils import utcnow
|
||||||
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class DatabaseCleanupService:
|
|||||||
# 查找指定天数前的已使用验证码记录
|
# 查找指定天数前的已使用验证码记录
|
||||||
cutoff_time = utcnow() - timedelta(days=days_old)
|
cutoff_time = utcnow() - timedelta(days=days_old)
|
||||||
|
|
||||||
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
stmt = select(EmailVerification).where(EmailVerification.is_used == True)
|
||||||
result = await db.exec(stmt)
|
result = await db.exec(stmt)
|
||||||
all_used_codes = result.all()
|
all_used_codes = result.all()
|
||||||
|
|
||||||
@@ -134,6 +134,50 @@ class DatabaseCleanupService:
|
|||||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
|
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cleanup_unverified_login_sessions(db: AsyncSession, hours_old: int = 1) -> int:
|
||||||
|
"""
|
||||||
|
清理指定小时前创建但仍未验证的登录会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
hours_old: 清理多少小时前创建但仍未验证的会话,默认1小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 清理的记录数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 计算截止时间
|
||||||
|
cutoff_time = utcnow() - timedelta(hours=hours_old)
|
||||||
|
|
||||||
|
# 查找指定时间前创建且仍未验证的会话记录
|
||||||
|
stmt = select(LoginSession).where(
|
||||||
|
LoginSession.is_verified == False,
|
||||||
|
LoginSession.created_at < cutoff_time
|
||||||
|
)
|
||||||
|
result = await db.exec(stmt)
|
||||||
|
unverified_sessions = result.all()
|
||||||
|
|
||||||
|
# 删除未验证的会话记录
|
||||||
|
deleted_count = 0
|
||||||
|
for session in unverified_sessions:
|
||||||
|
await db.delete(session)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"[Cleanup Service] Cleaned up {deleted_count} unverified login sessions older than {hours_old} hour(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"[Cleanup Service] Error cleaning unverified login sessions: {e!s}")
|
||||||
|
return 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||||
"""
|
"""
|
||||||
@@ -150,7 +194,7 @@ class DatabaseCleanupService:
|
|||||||
# 查找指定天数前的已验证会话记录
|
# 查找指定天数前的已验证会话记录
|
||||||
cutoff_time = utcnow() - timedelta(days=days_old)
|
cutoff_time = utcnow() - timedelta(days=days_old)
|
||||||
|
|
||||||
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
stmt = select(LoginSession).where(LoginSession.is_verified == True)
|
||||||
result = await db.exec(stmt)
|
result = await db.exec(stmt)
|
||||||
all_verified_sessions = result.all()
|
all_verified_sessions = result.all()
|
||||||
|
|
||||||
@@ -200,6 +244,9 @@ class DatabaseCleanupService:
|
|||||||
# 清理过期的登录会话
|
# 清理过期的登录会话
|
||||||
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||||
|
|
||||||
|
# 清理1小时前未验证的登录会话
|
||||||
|
results["unverified_login_sessions"] = await DatabaseCleanupService.cleanup_unverified_login_sessions(db, 1)
|
||||||
|
|
||||||
# 清理7天前的已使用验证码
|
# 清理7天前的已使用验证码
|
||||||
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||||
|
|
||||||
@@ -227,6 +274,7 @@ class DatabaseCleanupService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
current_time = utcnow()
|
current_time = utcnow()
|
||||||
|
cutoff_1_hour = current_time - timedelta(hours=1)
|
||||||
cutoff_7_days = current_time - timedelta(days=7)
|
cutoff_7_days = current_time - timedelta(days=7)
|
||||||
cutoff_30_days = current_time - timedelta(days=30)
|
cutoff_30_days = current_time - timedelta(days=30)
|
||||||
|
|
||||||
@@ -240,8 +288,16 @@ class DatabaseCleanupService:
|
|||||||
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||||
expired_sessions_count = len(expired_sessions_result.all())
|
expired_sessions_count = len(expired_sessions_result.all())
|
||||||
|
|
||||||
|
# 统计1小时前未验证的登录会话数量
|
||||||
|
unverified_sessions_stmt = select(LoginSession).where(
|
||||||
|
LoginSession.is_verified == False,
|
||||||
|
LoginSession.created_at < cutoff_1_hour
|
||||||
|
)
|
||||||
|
unverified_sessions_result = await db.exec(unverified_sessions_stmt)
|
||||||
|
unverified_sessions_count = len(unverified_sessions_result.all())
|
||||||
|
|
||||||
# 统计7天前的已使用验证码数量
|
# 统计7天前的已使用验证码数量
|
||||||
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
old_used_codes_stmt = select(EmailVerification).where(EmailVerification.is_used == True)
|
||||||
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
||||||
all_used_codes = old_used_codes_result.all()
|
all_used_codes = old_used_codes_result.all()
|
||||||
old_used_codes_count = len(
|
old_used_codes_count = len(
|
||||||
@@ -249,7 +305,7 @@ class DatabaseCleanupService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 统计30天前的已验证会话数量
|
# 统计30天前的已验证会话数量
|
||||||
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
old_verified_sessions_stmt = select(LoginSession).where(LoginSession.is_verified == True)
|
||||||
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||||
all_verified_sessions = old_verified_sessions_result.all()
|
all_verified_sessions = old_verified_sessions_result.all()
|
||||||
old_verified_sessions_count = len(
|
old_verified_sessions_count = len(
|
||||||
@@ -263,10 +319,12 @@ class DatabaseCleanupService:
|
|||||||
return {
|
return {
|
||||||
"expired_verification_codes": expired_codes_count,
|
"expired_verification_codes": expired_codes_count,
|
||||||
"expired_login_sessions": expired_sessions_count,
|
"expired_login_sessions": expired_sessions_count,
|
||||||
|
"unverified_login_sessions": unverified_sessions_count,
|
||||||
"old_used_verification_codes": old_used_codes_count,
|
"old_used_verification_codes": old_used_codes_count,
|
||||||
"old_verified_sessions": old_verified_sessions_count,
|
"old_verified_sessions": old_verified_sessions_count,
|
||||||
"total_cleanable": expired_codes_count
|
"total_cleanable": expired_codes_count
|
||||||
+ expired_sessions_count
|
+ expired_sessions_count
|
||||||
|
+ unverified_sessions_count
|
||||||
+ old_used_codes_count
|
+ old_used_codes_count
|
||||||
+ old_verified_sessions_count,
|
+ old_verified_sessions_count,
|
||||||
}
|
}
|
||||||
@@ -276,6 +334,7 @@ class DatabaseCleanupService:
|
|||||||
return {
|
return {
|
||||||
"expired_verification_codes": 0,
|
"expired_verification_codes": 0,
|
||||||
"expired_login_sessions": 0,
|
"expired_login_sessions": 0,
|
||||||
|
"unverified_login_sessions": 0,
|
||||||
"old_used_verification_codes": 0,
|
"old_used_verification_codes": 0,
|
||||||
"old_verified_sessions": 0,
|
"old_verified_sessions": 0,
|
||||||
"total_cleanable": 0,
|
"total_cleanable": 0,
|
||||||
|
|||||||
Reference in New Issue
Block a user