refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -63,10 +63,7 @@ class BeatmapCacheService:
|
||||
if preload_tasks:
|
||||
results = await asyncio.gather(*preload_tasks, return_exceptions=True)
|
||||
success_count = sum(1 for r in results if r is True)
|
||||
logger.info(
|
||||
f"Preloaded {success_count}/{len(preload_tasks)} "
|
||||
f"beatmaps successfully"
|
||||
)
|
||||
logger.info(f"Preloaded {success_count}/{len(preload_tasks)} beatmaps successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during beatmap preloading: {e}")
|
||||
@@ -119,9 +116,7 @@ class BeatmapCacheService:
|
||||
|
||||
return {
|
||||
"cached_beatmaps": len(keys),
|
||||
"estimated_total_size_mb": (
|
||||
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||
),
|
||||
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
|
||||
"preloading": self._preloading,
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -155,9 +150,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
|
||||
return _cache_service
|
||||
|
||||
|
||||
async def schedule_preload_task(
|
||||
session: AsyncSession, redis: Redis, fetcher: "Fetcher"
|
||||
):
|
||||
async def schedule_preload_task(session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
||||
"""
|
||||
定时预加载任务
|
||||
"""
|
||||
|
||||
@@ -192,22 +192,16 @@ class BeatmapDownloadService:
|
||||
healthy_endpoints.sort(key=lambda x: x.priority)
|
||||
return healthy_endpoints
|
||||
|
||||
def get_download_url(
|
||||
self, beatmapset_id: int, no_video: bool, is_china: bool
|
||||
) -> str:
|
||||
def get_download_url(self, beatmapset_id: int, no_video: bool, is_china: bool) -> str:
|
||||
"""获取下载URL,带负载均衡和故障转移"""
|
||||
healthy_endpoints = self.get_healthy_endpoints(is_china)
|
||||
|
||||
if not healthy_endpoints:
|
||||
# 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的
|
||||
logger.error(f"No healthy endpoints available for is_china={is_china}")
|
||||
endpoints = (
|
||||
self.china_endpoints if is_china else self.international_endpoints
|
||||
)
|
||||
endpoints = self.china_endpoints if is_china else self.international_endpoints
|
||||
if not endpoints:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="No download endpoints available"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="No download endpoints available")
|
||||
endpoint = min(endpoints, key=lambda x: x.priority)
|
||||
else:
|
||||
# 使用第一个健康的端点(已按优先级排序)
|
||||
@@ -218,9 +212,7 @@ class BeatmapDownloadService:
|
||||
video_type = "novideo" if no_video else "full"
|
||||
return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
|
||||
elif endpoint.name == "Nerinyan":
|
||||
return endpoint.url_template.format(
|
||||
sid=beatmapset_id, no_video="true" if no_video else "false"
|
||||
)
|
||||
return endpoint.url_template.format(sid=beatmapset_id, no_video="true" if no_video else "false")
|
||||
elif endpoint.name == "OsuDirect":
|
||||
# osu.direct 似乎没有no_video参数,直接使用基础URL
|
||||
return endpoint.url_template.format(sid=beatmapset_id)
|
||||
@@ -239,9 +231,7 @@ class BeatmapDownloadService:
|
||||
for name, status in self.endpoint_status.items():
|
||||
status_info["endpoints"][name] = {
|
||||
"healthy": status.is_healthy,
|
||||
"last_check": status.last_check.isoformat()
|
||||
if status.last_check
|
||||
else None,
|
||||
"last_check": status.last_check.isoformat() if status.last_check else None,
|
||||
"consecutive_failures": status.consecutive_failures,
|
||||
"last_error": status.last_error,
|
||||
"priority": status.endpoint.priority,
|
||||
|
||||
@@ -11,9 +11,7 @@ from app.models.score import GameMode
|
||||
from sqlmodel import col, exists, select, update
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job(
|
||||
"cron", hour=0, minute=0, second=0, id="calculate_user_rank"
|
||||
)
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="calculate_user_rank")
|
||||
async def calculate_user_rank(is_today: bool = False):
|
||||
today = datetime.now(UTC).date()
|
||||
target_date = today if is_today else today - timedelta(days=1)
|
||||
|
||||
@@ -11,9 +11,7 @@ from sqlmodel import exists, select
|
||||
|
||||
async def create_banchobot():
|
||||
async with with_db() as session:
|
||||
is_exist = (
|
||||
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
|
||||
).first()
|
||||
is_exist = (await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))).first()
|
||||
if not is_exist:
|
||||
banchobot = User(
|
||||
username="BanchoBot",
|
||||
|
||||
@@ -82,8 +82,7 @@ async def daily_challenge_job():
|
||||
|
||||
if beatmap is None or ruleset_id is None:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Missing required data for daily challenge {now}."
|
||||
" Will try again in 5 minutes."
|
||||
f"[DailyChallenge] Missing required data for daily challenge {now}. Will try again in 5 minutes."
|
||||
)
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
@@ -104,9 +103,7 @@ async def daily_challenge_job():
|
||||
else:
|
||||
allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list)
|
||||
|
||||
next_day = (now + timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
next_day = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
room = await create_daily_challenge_room(
|
||||
beatmap=beatmap_int,
|
||||
ruleset_id=ruleset_id_int,
|
||||
@@ -114,24 +111,13 @@ async def daily_challenge_job():
|
||||
allowed_mods=allowed_mods_list,
|
||||
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
|
||||
)
|
||||
await MetadataHubs.broadcast_call(
|
||||
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
|
||||
)
|
||||
logger.success(
|
||||
"[DailyChallenge] Added today's daily challenge: "
|
||||
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
|
||||
)
|
||||
await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
|
||||
logger.success(f"[DailyChallenge] Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
|
||||
return
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Error processing daily challenge data: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
logger.warning(f"[DailyChallenge] Error processing daily challenge data: {e} Will try again in 5 minutes.")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
logger.exception(f"[DailyChallenge] Unexpected error in daily challenge job: {e} Will try again in 5 minutes.")
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
"date",
|
||||
@@ -139,9 +125,7 @@ async def daily_challenge_job():
|
||||
)
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job(
|
||||
"cron", hour=0, minute=1, second=0, id="daily_challenge_last_top"
|
||||
)
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=1, second=0, id="daily_challenge_last_top")
|
||||
async def process_daily_challenge_top():
|
||||
async with with_db() as session:
|
||||
now = datetime.now(UTC)
|
||||
@@ -182,11 +166,7 @@ async def process_daily_challenge_top():
|
||||
await session.commit()
|
||||
del s
|
||||
|
||||
user_ids = (
|
||||
await session.exec(
|
||||
select(User.id).where(col(User.id).not_in(participated_users))
|
||||
)
|
||||
).all()
|
||||
user_ids = (await session.exec(select(User.id).where(col(User.id).not_in(participated_users)))).all()
|
||||
for id in user_ids:
|
||||
stats = await session.get(DailyChallengeStats, id)
|
||||
if stats is None: # not execute
|
||||
|
||||
@@ -4,14 +4,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
class DatabaseCleanupService:
|
||||
@@ -21,211 +20,207 @@ class DatabaseCleanupService:
|
||||
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的邮件验证码
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的验证码记录
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stmt = select(EmailVerification).where(
|
||||
EmailVerification.expires_at < current_time
|
||||
)
|
||||
|
||||
stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
|
||||
result = await db.exec(stmt)
|
||||
expired_codes = result.all()
|
||||
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for code in expired_codes:
|
||||
await db.delete(code)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
|
||||
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的登录会话
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的登录会话记录
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time
|
||||
)
|
||||
|
||||
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
result = await db.exec(stmt)
|
||||
expired_sessions = result.all()
|
||||
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for session in expired_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
|
||||
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
|
||||
"""
|
||||
清理旧的已使用验证码记录
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已使用记录,默认7天
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已使用验证码记录
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
|
||||
stmt = select(EmailVerification).where(
|
||||
EmailVerification.is_used == True
|
||||
)
|
||||
|
||||
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
result = await db.exec(stmt)
|
||||
all_used_codes = result.all()
|
||||
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_used_codes = [
|
||||
code for code in all_used_codes
|
||||
if code.used_at and code.used_at < cutoff_time
|
||||
]
|
||||
|
||||
old_used_codes = [code for code in all_used_codes if code.used_at and code.used_at < cutoff_time]
|
||||
|
||||
# 删除旧的已使用记录
|
||||
deleted_count = 0
|
||||
for code in old_used_codes:
|
||||
await db.delete(code)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
|
||||
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days"
|
||||
)
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||
"""
|
||||
清理旧的已验证会话记录
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已验证记录,默认30天
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已验证会话记录
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == True
|
||||
)
|
||||
|
||||
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
result = await db.exec(stmt)
|
||||
all_verified_sessions = result.all()
|
||||
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_verified_sessions = [
|
||||
session for session in all_verified_sessions
|
||||
session
|
||||
for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_time
|
||||
]
|
||||
|
||||
|
||||
# 删除旧的已验证记录
|
||||
deleted_count = 0
|
||||
for session in old_verified_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
|
||||
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
|
||||
)
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
|
||||
"""
|
||||
运行完整的清理流程
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 各项清理的结果统计
|
||||
"""
|
||||
results = {}
|
||||
|
||||
|
||||
# 清理过期的验证码
|
||||
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||
|
||||
|
||||
# 清理过期的登录会话
|
||||
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
|
||||
# 清理7天前的已使用验证码
|
||||
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||
|
||||
|
||||
# 清理30天前的已验证会话
|
||||
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
|
||||
|
||||
|
||||
total_cleaned = sum(results.values())
|
||||
if total_cleaned > 0:
|
||||
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
|
||||
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
|
||||
"""
|
||||
获取清理统计信息
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 统计信息
|
||||
"""
|
||||
@@ -233,57 +228,54 @@ class DatabaseCleanupService:
|
||||
current_time = datetime.now(UTC)
|
||||
cutoff_7_days = current_time - timedelta(days=7)
|
||||
cutoff_30_days = current_time - timedelta(days=30)
|
||||
|
||||
|
||||
# 统计过期的验证码数量
|
||||
expired_codes_stmt = select(EmailVerification).where(
|
||||
EmailVerification.expires_at < current_time
|
||||
)
|
||||
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
|
||||
expired_codes_result = await db.exec(expired_codes_stmt)
|
||||
expired_codes_count = len(expired_codes_result.all())
|
||||
|
||||
|
||||
# 统计过期的登录会话数量
|
||||
expired_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time
|
||||
)
|
||||
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||
expired_sessions_count = len(expired_sessions_result.all())
|
||||
|
||||
|
||||
# 统计7天前的已使用验证码数量
|
||||
old_used_codes_stmt = select(EmailVerification).where(
|
||||
EmailVerification.is_used == True
|
||||
)
|
||||
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
||||
all_used_codes = old_used_codes_result.all()
|
||||
old_used_codes_count = len([
|
||||
code for code in all_used_codes
|
||||
if code.used_at and code.used_at < cutoff_7_days
|
||||
])
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == True
|
||||
old_used_codes_count = len(
|
||||
[code for code in all_used_codes if code.used_at and code.used_at < cutoff_7_days]
|
||||
)
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||
all_verified_sessions = old_verified_sessions_result.all()
|
||||
old_verified_sessions_count = len([
|
||||
session for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_30_days
|
||||
])
|
||||
|
||||
old_verified_sessions_count = len(
|
||||
[
|
||||
session
|
||||
for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_30_days
|
||||
]
|
||||
)
|
||||
|
||||
return {
|
||||
"expired_verification_codes": expired_codes_count,
|
||||
"expired_login_sessions": expired_sessions_count,
|
||||
"old_used_verification_codes": old_used_codes_count,
|
||||
"old_verified_sessions": old_verified_sessions_count,
|
||||
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
|
||||
"total_cleanable": expired_codes_count
|
||||
+ expired_sessions_count
|
||||
+ old_used_codes_count
|
||||
+ old_verified_sessions_count,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {e!s}")
|
||||
return {
|
||||
"expired_verification_codes": 0,
|
||||
"expired_login_sessions": 0,
|
||||
"old_used_verification_codes": 0,
|
||||
"old_verified_sessions": 0,
|
||||
"total_cleanable": 0
|
||||
"total_cleanable": 0,
|
||||
}
|
||||
|
||||
@@ -8,17 +8,18 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from datetime import datetime
|
||||
import json
|
||||
import uuid
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from typing import Dict, Any, Optional
|
||||
import redis as sync_redis # 添加同步Redis导入
|
||||
from email.mime.text import MIMEText
|
||||
import json
|
||||
import smtplib
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import redis_message_client # 使用同步Redis客户端
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks # 添加同步Redis导入
|
||||
|
||||
import redis as sync_redis
|
||||
|
||||
|
||||
class EmailQueue:
|
||||
@@ -30,14 +31,14 @@ class EmailQueue:
|
||||
self._processing = False
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||
self._retry_limit = 3 # 重试次数限制
|
||||
|
||||
|
||||
# 邮件配置
|
||||
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
|
||||
self.smtp_port = getattr(settings, 'smtp_port', 587)
|
||||
self.smtp_username = getattr(settings, 'smtp_username', '')
|
||||
self.smtp_password = getattr(settings, 'smtp_password', '')
|
||||
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
|
||||
self.from_name = getattr(settings, 'from_name', 'osu! server')
|
||||
self.smtp_server = getattr(settings, "smtp_server", "localhost")
|
||||
self.smtp_port = getattr(settings, "smtp_port", 587)
|
||||
self.smtp_username = getattr(settings, "smtp_username", "")
|
||||
self.smtp_password = getattr(settings, "smtp_password", "")
|
||||
self.from_email = getattr(settings, "from_email", "noreply@example.com")
|
||||
self.from_name = getattr(settings, "from_name", "osu! server")
|
||||
|
||||
async def _run_in_executor(self, func, *args):
|
||||
"""在线程池中运行同步操作"""
|
||||
@@ -48,7 +49,7 @@ class EmailQueue:
|
||||
"""启动邮件处理任务"""
|
||||
if not self._processing:
|
||||
self._processing = True
|
||||
asyncio.create_task(self._process_email_queue())
|
||||
bg_tasks.add_task(self._process_email_queue)
|
||||
logger.info("Email queue processing started")
|
||||
|
||||
async def stop_processing(self):
|
||||
@@ -56,27 +57,29 @@ class EmailQueue:
|
||||
self._processing = False
|
||||
logger.info("Email queue processing stopped")
|
||||
|
||||
async def enqueue_email(self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
content: str,
|
||||
html_content: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> str:
|
||||
async def enqueue_email(
|
||||
self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
content: str,
|
||||
html_content: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
将邮件加入队列等待发送
|
||||
|
||||
|
||||
Args:
|
||||
to_email: 收件人邮箱地址
|
||||
subject: 邮件主题
|
||||
content: 邮件纯文本内容
|
||||
html_content: 邮件HTML内容(如果有)
|
||||
metadata: 额外元数据(如密码重置ID等)
|
||||
|
||||
|
||||
Returns:
|
||||
邮件任务ID
|
||||
"""
|
||||
email_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
email_data = {
|
||||
"id": email_id,
|
||||
"to_email": to_email,
|
||||
@@ -86,125 +89,117 @@ class EmailQueue:
|
||||
"metadata": json.dumps(metadata) if metadata else "{}",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"status": "pending", # pending, sending, sent, failed
|
||||
"retry_count": "0"
|
||||
"retry_count": "0",
|
||||
}
|
||||
|
||||
|
||||
# 将邮件数据存入Redis
|
||||
await self._run_in_executor(
|
||||
lambda: self.redis.hset(f"email:{email_id}", mapping=email_data)
|
||||
)
|
||||
|
||||
await self._run_in_executor(lambda: self.redis.hset(f"email:{email_id}", mapping=email_data))
|
||||
|
||||
# 设置24小时过期(防止数据堆积)
|
||||
await self._run_in_executor(
|
||||
self.redis.expire, f"email:{email_id}", 86400
|
||||
)
|
||||
|
||||
await self._run_in_executor(self.redis.expire, f"email:{email_id}", 86400)
|
||||
|
||||
# 加入发送队列
|
||||
await self._run_in_executor(
|
||||
self.redis.lpush, "email_queue", email_id
|
||||
)
|
||||
|
||||
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
|
||||
|
||||
logger.info(f"Email enqueued with id: {email_id} to {to_email}")
|
||||
return email_id
|
||||
|
||||
async def get_email_status(self, email_id: str) -> Dict[str, Any]:
|
||||
async def get_email_status(self, email_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
获取邮件发送状态
|
||||
|
||||
|
||||
Args:
|
||||
email_id: 邮件任务ID
|
||||
|
||||
|
||||
Returns:
|
||||
邮件任务状态信息
|
||||
"""
|
||||
email_data = await self._run_in_executor(
|
||||
self.redis.hgetall, f"email:{email_id}"
|
||||
)
|
||||
|
||||
email_data = await self._run_in_executor(self.redis.hgetall, f"email:{email_id}")
|
||||
|
||||
# 解码Redis返回的字节数据
|
||||
if email_data:
|
||||
return {
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k:
|
||||
v.decode("utf-8") if isinstance(v, bytes) else v
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
|
||||
for k, v in email_data.items()
|
||||
}
|
||||
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
async def _process_email_queue(self):
|
||||
"""处理邮件队列"""
|
||||
logger.info("Starting email queue processor")
|
||||
|
||||
|
||||
while self._processing:
|
||||
try:
|
||||
# 从队列获取邮件ID
|
||||
def brpop_operation():
|
||||
return self.redis.brpop(["email_queue"], timeout=5)
|
||||
|
||||
|
||||
result = await self._run_in_executor(brpop_operation)
|
||||
|
||||
|
||||
if not result:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
|
||||
# 解包返回结果(列表名和值)
|
||||
queue_name, email_id = result
|
||||
if isinstance(email_id, bytes):
|
||||
email_id = email_id.decode("utf-8")
|
||||
|
||||
|
||||
# 获取邮件数据
|
||||
email_data = await self.get_email_status(email_id)
|
||||
if email_data.get("status") == "not_found":
|
||||
logger.warning(f"Email data not found for id: {email_id}")
|
||||
continue
|
||||
|
||||
|
||||
# 更新状态为发送中
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "sending"
|
||||
)
|
||||
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sending")
|
||||
|
||||
# 尝试发送邮件
|
||||
success = await self._send_email(email_data)
|
||||
|
||||
|
||||
if success:
|
||||
# 更新状态为已发送
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sent")
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "sent"
|
||||
)
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "sent_at", datetime.now().isoformat()
|
||||
self.redis.hset,
|
||||
f"email:{email_id}",
|
||||
"sent_at",
|
||||
datetime.now().isoformat(),
|
||||
)
|
||||
logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}")
|
||||
else:
|
||||
# 计算重试次数
|
||||
retry_count = int(email_data.get("retry_count", "0")) + 1
|
||||
|
||||
|
||||
if retry_count <= self._retry_limit:
|
||||
# 重新入队,稍后重试
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "retry_count", str(retry_count)
|
||||
self.redis.hset,
|
||||
f"email:{email_id}",
|
||||
"retry_count",
|
||||
str(retry_count),
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "pending")
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "pending"
|
||||
self.redis.hset,
|
||||
f"email:{email_id}",
|
||||
"last_retry",
|
||||
datetime.now().isoformat(),
|
||||
)
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "last_retry", datetime.now().isoformat()
|
||||
)
|
||||
|
||||
|
||||
# 延迟重试(使用指数退避)
|
||||
delay = 60 * (2 ** (retry_count - 1)) # 1分钟,2分钟,4分钟...
|
||||
|
||||
|
||||
# 创建延迟任务
|
||||
asyncio.create_task(self._delayed_retry(email_id, delay))
|
||||
|
||||
bg_tasks.add_task(self._delayed_retry, email_id, delay)
|
||||
|
||||
logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})")
|
||||
else:
|
||||
# 超过重试次数,标记为失败
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "failed"
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "failed")
|
||||
logger.error(f"Email {email_id} failed after {retry_count} attempts")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing email queue: {e}")
|
||||
await asyncio.sleep(5) # 出错后等待5秒
|
||||
@@ -212,53 +207,51 @@ class EmailQueue:
|
||||
async def _delayed_retry(self, email_id: str, delay: int):
|
||||
"""延迟重试发送邮件"""
|
||||
await asyncio.sleep(delay)
|
||||
await self._run_in_executor(
|
||||
self.redis.lpush, "email_queue", email_id
|
||||
)
|
||||
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
|
||||
logger.info(f"Re-queued email {email_id} for retry after {delay} seconds")
|
||||
|
||||
async def _send_email(self, email_data: Dict[str, Any]) -> bool:
|
||||
async def _send_email(self, email_data: dict[str, Any]) -> bool:
|
||||
"""
|
||||
实际发送邮件
|
||||
|
||||
|
||||
Args:
|
||||
email_data: 邮件数据
|
||||
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 如果邮件发送功能被禁用,则只记录日志
|
||||
if not getattr(settings, 'enable_email_sending', True):
|
||||
if not getattr(settings, "enable_email_sending", True):
|
||||
logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}")
|
||||
return True
|
||||
|
||||
|
||||
# 创建邮件
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['From'] = f"{self.from_name} <{self.from_email}>"
|
||||
msg['To'] = email_data.get('to_email', '')
|
||||
msg['Subject'] = email_data.get('subject', '')
|
||||
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = f"{self.from_name} <{self.from_email}>"
|
||||
msg["To"] = email_data.get("to_email", "")
|
||||
msg["Subject"] = email_data.get("subject", "")
|
||||
|
||||
# 添加纯文本内容
|
||||
content = email_data.get('content', '')
|
||||
content = email_data.get("content", "")
|
||||
if content:
|
||||
msg.attach(MIMEText(content, 'plain', 'utf-8'))
|
||||
|
||||
msg.attach(MIMEText(content, "plain", "utf-8"))
|
||||
|
||||
# 添加HTML内容(如果有)
|
||||
html_content = email_data.get('html_content', '')
|
||||
html_content = email_data.get("html_content", "")
|
||||
if html_content:
|
||||
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
|
||||
|
||||
msg.attach(MIMEText(html_content, "html", "utf-8"))
|
||||
|
||||
# 发送邮件
|
||||
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
|
||||
if self.smtp_username and self.smtp_password:
|
||||
server.starttls()
|
||||
server.login(self.smtp_username, self.smtp_password)
|
||||
|
||||
|
||||
server.send_message(msg)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email: {e}")
|
||||
return False
|
||||
@@ -267,10 +260,12 @@ class EmailQueue:
|
||||
# 全局邮件队列实例
|
||||
email_queue = EmailQueue()
|
||||
|
||||
|
||||
# 在应用启动时调用
|
||||
async def start_email_processor():
|
||||
await email_queue.start_processing()
|
||||
|
||||
|
||||
# 在应用关闭时调用
|
||||
async def stop_email_processor():
|
||||
await email_queue.stop_processing()
|
||||
|
||||
@@ -4,13 +4,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
import secrets
|
||||
import smtplib
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
@@ -18,28 +16,28 @@ from app.log import logger
|
||||
|
||||
class EmailService:
|
||||
"""邮件发送服务"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
|
||||
self.smtp_port = getattr(settings, 'smtp_port', 587)
|
||||
self.smtp_username = getattr(settings, 'smtp_username', '')
|
||||
self.smtp_password = getattr(settings, 'smtp_password', '')
|
||||
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
|
||||
self.from_name = getattr(settings, 'from_name', 'osu! server')
|
||||
|
||||
self.smtp_server = getattr(settings, "smtp_server", "localhost")
|
||||
self.smtp_port = getattr(settings, "smtp_port", 587)
|
||||
self.smtp_username = getattr(settings, "smtp_username", "")
|
||||
self.smtp_password = getattr(settings, "smtp_password", "")
|
||||
self.from_email = getattr(settings, "from_email", "noreply@example.com")
|
||||
self.from_name = getattr(settings, "from_name", "osu! server")
|
||||
|
||||
def generate_verification_code(self) -> str:
|
||||
"""生成8位验证码"""
|
||||
# 只使用数字,避免混淆
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
return "".join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = f"{self.from_name} <{self.from_email}>"
|
||||
msg['To'] = email
|
||||
msg['Subject'] = "邮箱验证 - Email Verification"
|
||||
|
||||
msg["From"] = f"{self.from_name} <{self.from_email}>"
|
||||
msg["To"] = email
|
||||
msg["Subject"] = "邮箱验证 - Email Verification"
|
||||
|
||||
# HTML 邮件内容
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
@@ -101,15 +99,15 @@ class EmailService:
|
||||
<h1>osu! 邮箱验证</h1>
|
||||
<p>Email Verification</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="content">
|
||||
<h2>你好 {username}!</h2>
|
||||
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
|
||||
|
||||
|
||||
<div class="code">{code}</div>
|
||||
|
||||
|
||||
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
|
||||
|
||||
|
||||
<div class="warning">
|
||||
<strong>注意:</strong>
|
||||
<ul>
|
||||
@@ -118,19 +116,19 @@ class EmailService:
|
||||
<li>验证码只能使用一次</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
<p>如果你有任何问题,请联系我们的支持团队。</p>
|
||||
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||
|
||||
|
||||
<h3>Hello {username}!</h3>
|
||||
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
|
||||
|
||||
|
||||
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
|
||||
|
||||
|
||||
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="footer">
|
||||
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>This email was sent automatically, please do not reply.</p>
|
||||
@@ -138,26 +136,26 @@ class EmailService:
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
msg.attach(MIMEText(html_content, "html", "utf-8"))
|
||||
|
||||
# 发送邮件
|
||||
if not settings.enable_email_sending:
|
||||
# 邮件发送功能禁用时只记录日志,不实际发送
|
||||
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
|
||||
return True
|
||||
|
||||
|
||||
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
|
||||
if self.smtp_username and self.smtp_password:
|
||||
server.starttls()
|
||||
server.login(self.smtp_username, self.smtp_password)
|
||||
|
||||
|
||||
server.send_message(msg)
|
||||
|
||||
|
||||
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Failed to send email: {e}")
|
||||
return False
|
||||
|
||||
@@ -4,40 +4,38 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_service import email_service
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.log import logger
|
||||
from app.config import settings
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel import select
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class EmailVerificationService:
|
||||
"""邮件验证服务"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_verification_code() -> str:
|
||||
"""生成8位验证码"""
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
return "".join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
@staticmethod
|
||||
async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool:
|
||||
"""使用邮件队列发送验证邮件
|
||||
|
||||
|
||||
Args:
|
||||
email: 接收验证码的邮箱地址
|
||||
code: 验证码
|
||||
username: 用户名
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功将邮件加入队列
|
||||
"""
|
||||
@@ -103,15 +101,15 @@ class EmailVerificationService:
|
||||
<h1>osu! 邮箱验证</h1>
|
||||
<p>Email Verification</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="content">
|
||||
<h2>你好 {username}!</h2>
|
||||
<p>请使用以下验证码验证您的账户:</p>
|
||||
|
||||
|
||||
<div class="code">{code}</div>
|
||||
|
||||
|
||||
<p>验证码将在 <strong>10 分钟内有效</strong>。</p>
|
||||
|
||||
|
||||
<div class="warning">
|
||||
<p><strong>重要提示:</strong></p>
|
||||
<ul>
|
||||
@@ -120,17 +118,17 @@ class EmailVerificationService:
|
||||
<li>为了账户安全,请勿在其他网站使用相同的密码</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||
|
||||
|
||||
<h3>Hello {username}!</h3>
|
||||
<p>Please use the following verification code to verify your account:</p>
|
||||
|
||||
|
||||
<p>This verification code will be valid for <strong>10 minutes</strong>.</p>
|
||||
|
||||
|
||||
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="footer">
|
||||
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>This email was sent automatically, please do not reply.</p>
|
||||
@@ -138,8 +136,8 @@ class EmailVerificationService:
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
# 纯文本备用内容
|
||||
plain_content = f"""
|
||||
你好 {username}!
|
||||
@@ -162,34 +160,30 @@ This verification code will be valid for 10 minutes.
|
||||
© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。
|
||||
This email was sent automatically, please do not reply.
|
||||
"""
|
||||
|
||||
|
||||
# 将邮件加入队列
|
||||
subject = "邮箱验证 - Email Verification"
|
||||
metadata = {
|
||||
"type": "email_verification",
|
||||
"user_id": user_id,
|
||||
"code": code
|
||||
}
|
||||
|
||||
metadata = {"type": "email_verification", "user_id": user_id, "code": code}
|
||||
|
||||
await email_queue.enqueue_email(
|
||||
to_email=email,
|
||||
subject=subject,
|
||||
content=plain_content,
|
||||
html_content=html_content,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Failed to enqueue email: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_session_token() -> str:
|
||||
"""生成会话令牌"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def create_verification_record(
|
||||
db: AsyncSession,
|
||||
@@ -197,27 +191,27 @@ This email was sent automatically, please do not reply.
|
||||
user_id: int,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
user_agent: str | None = None,
|
||||
) -> tuple[EmailVerification, str]:
|
||||
"""创建邮件验证记录"""
|
||||
|
||||
|
||||
# 检查是否有未过期的验证码
|
||||
existing_result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
col(EmailVerification.is_used).is_(False),
|
||||
EmailVerification.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
existing = existing_result.first()
|
||||
|
||||
|
||||
if existing:
|
||||
# 如果有未过期的验证码,直接返回
|
||||
return existing, existing.verification_code
|
||||
|
||||
|
||||
# 生成新的验证码
|
||||
code = EmailVerificationService.generate_verification_code()
|
||||
|
||||
|
||||
# 创建验证记录
|
||||
verification = EmailVerification(
|
||||
user_id=user_id,
|
||||
@@ -225,23 +219,23 @@ This email was sent automatically, please do not reply.
|
||||
verification_code=code,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
|
||||
db.add(verification)
|
||||
await db.commit()
|
||||
await db.refresh(verification)
|
||||
|
||||
|
||||
# 存储到 Redis(用于快速验证)
|
||||
await redis.setex(
|
||||
f"email_verification:{user_id}:{code}",
|
||||
600, # 10分钟过期
|
||||
str(verification.id) if verification.id else "0"
|
||||
str(verification.id) if verification.id else "0",
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
|
||||
return verification, code
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def send_verification_email(
|
||||
db: AsyncSession,
|
||||
@@ -250,7 +244,7 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
user_agent: str | None = None,
|
||||
) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
@@ -258,33 +252,38 @@ This email was sent automatically, please do not reply.
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
|
||||
return True # 返回成功,但不执行验证流程
|
||||
|
||||
|
||||
# 创建验证记录
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
(
|
||||
verification,
|
||||
code,
|
||||
) = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
# 使用邮件队列发送验证邮件
|
||||
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})")
|
||||
logger.info(
|
||||
f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def verify_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
code: str,
|
||||
ip_address: str | None = None
|
||||
ip_address: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
try:
|
||||
@@ -294,46 +293,46 @@ This email was sent automatically, please do not reply.
|
||||
# 仍然标记登录会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
return True, "验证成功(邮件验证功能已禁用)"
|
||||
|
||||
|
||||
# 先从 Redis 检查
|
||||
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
|
||||
if not verification_id:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
|
||||
# 从数据库获取验证记录
|
||||
result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.id == int(verification_id),
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.verification_code == code,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
col(EmailVerification.is_used).is_(False),
|
||||
EmailVerification.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
verification = result.first()
|
||||
if not verification:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
|
||||
# 标记为已使用
|
||||
verification.is_used = True
|
||||
verification.used_at = datetime.now(UTC)
|
||||
|
||||
|
||||
# 同时更新对应的登录会话状态
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
# 删除 Redis 记录
|
||||
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||
|
||||
|
||||
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||
return True, "验证成功"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def resend_verification_code(
|
||||
db: AsyncSession,
|
||||
@@ -342,7 +341,7 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
user_agent: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""重新发送验证码"""
|
||||
try:
|
||||
@@ -350,25 +349,25 @@ This email was sent automatically, please do not reply.
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
|
||||
return True, "验证码已发送(邮件验证功能已禁用)"
|
||||
|
||||
|
||||
# 检查重发频率限制(60秒内只能发送一次)
|
||||
rate_limit_key = f"email_verification_rate_limit:{user_id}"
|
||||
if await redis.get(rate_limit_key):
|
||||
return False, "请等待60秒后再重新发送"
|
||||
|
||||
|
||||
# 设置频率限制
|
||||
await redis.setex(rate_limit_key, 60, "1")
|
||||
|
||||
|
||||
# 生成新的验证码
|
||||
success = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, username, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return True, "验证码已重新发送"
|
||||
else:
|
||||
return False, "重新发送失败,请稍后再试"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
|
||||
return False, "重新发送过程中发生错误"
|
||||
@@ -376,7 +375,7 @@ This email was sent automatically, please do not reply.
|
||||
|
||||
class LoginSessionService:
|
||||
"""登录会话服务"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
@@ -385,47 +384,40 @@ class LoginSessionService:
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False
|
||||
is_new_location: bool = False,
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
from app.utils import simplify_user_agent
|
||||
|
||||
|
||||
session_token = EmailVerificationService.generate_session_token()
|
||||
|
||||
# 简化 User-Agent 字符串
|
||||
simplified_user_agent = simplify_user_agent(user_agent, max_length=250)
|
||||
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
ip_address=ip_address,
|
||||
user_agent=simplified_user_agent,
|
||||
user_agent=None,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
|
||||
is_verified=not is_new_location # 新位置需要验证
|
||||
is_verified=not is_new_location, # 新位置需要验证
|
||||
)
|
||||
|
||||
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
|
||||
# 存储到 Redis
|
||||
await redis.setex(
|
||||
f"login_session:{session_token}",
|
||||
86400, # 24小时
|
||||
user_id
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
return session
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def verify_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
session_token: str,
|
||||
verification_code: str
|
||||
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
|
||||
) -> tuple[bool, str]:
|
||||
"""验证会话(通过邮件验证码)"""
|
||||
try:
|
||||
@@ -433,98 +425,89 @@ class LoginSessionService:
|
||||
user_id = await redis.get(f"login_session:{session_token}")
|
||||
if not user_id:
|
||||
return False, "会话无效或已过期"
|
||||
|
||||
|
||||
user_id = int(user_id)
|
||||
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_code
|
||||
)
|
||||
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
|
||||
|
||||
if not success:
|
||||
return False, message
|
||||
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_token,
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
session = result.first()
|
||||
if session:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
|
||||
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||
return True, "会话验证成功"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
country_code: str | None = None
|
||||
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
|
||||
) -> bool:
|
||||
"""检查是否为新位置登录"""
|
||||
try:
|
||||
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.created_at > thirty_days_ago,
|
||||
(LoginSession.ip_address == ip_address) |
|
||||
(LoginSession.country_code == country_code)
|
||||
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
existing_sessions = result.all()
|
||||
|
||||
|
||||
# 如果有历史记录,则不是新位置
|
||||
return len(existing_sessions) == 0
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||
# 出错时默认为新位置(更安全)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(
|
||||
db: AsyncSession,
|
||||
user_id: int
|
||||
) -> bool:
|
||||
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.expires_at > datetime.now(UTC)
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
sessions = result.all()
|
||||
|
||||
|
||||
# 标记所有会话为已验证
|
||||
for session in sessions:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
|
||||
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||
return False
|
||||
|
||||
@@ -117,14 +117,10 @@ class EnhancedIntervalStatsManager:
|
||||
@staticmethod
|
||||
async def get_current_interval_info() -> IntervalInfo:
|
||||
"""获取当前区间信息"""
|
||||
start_time, end_time = (
|
||||
EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
)
|
||||
start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
|
||||
|
||||
return IntervalInfo(
|
||||
start_time=start_time, end_time=end_time, interval_key=interval_key
|
||||
)
|
||||
return IntervalInfo(start_time=start_time, end_time=end_time, interval_key=interval_key)
|
||||
|
||||
@staticmethod
|
||||
async def initialize_current_interval() -> None:
|
||||
@@ -133,9 +129,7 @@ class EnhancedIntervalStatsManager:
|
||||
redis_async = get_redis()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
|
||||
# 存储当前区间信息
|
||||
await _redis_exec(
|
||||
@@ -147,9 +141,7 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
# 初始化区间用户集合(如果不存在)
|
||||
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
|
||||
playing_key = (
|
||||
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
)
|
||||
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
|
||||
# 设置过期时间为35分钟
|
||||
await redis_async.expire(online_key, 35 * 60)
|
||||
@@ -179,7 +171,8 @@ class EnhancedIntervalStatsManager:
|
||||
await EnhancedIntervalStatsManager._ensure_24h_history_exists()
|
||||
|
||||
logger.info(
|
||||
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}"
|
||||
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')}"
|
||||
f" - {current_interval.end_time.strftime('%H:%M')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -193,42 +186,32 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
try:
|
||||
# 检查现有历史数据数量
|
||||
history_length = await _redis_exec(
|
||||
redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
|
||||
)
|
||||
history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
|
||||
|
||||
if history_length < 48: # 少于48个数据点(24小时*2)
|
||||
logger.info(
|
||||
f"History has only {history_length} points, filling with zeros for 24h"
|
||||
)
|
||||
logger.info(f"History has only {history_length} points, filling with zeros for 24h")
|
||||
|
||||
# 计算需要填充的数据点数量
|
||||
needed_points = 48 - history_length
|
||||
|
||||
# 从当前时间往前推,创建缺失的时间点(都填充为0)
|
||||
current_time = datetime.utcnow()
|
||||
current_interval_start, _ = (
|
||||
EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
)
|
||||
current_time = datetime.utcnow() # noqa: F841
|
||||
current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
|
||||
# 从当前区间开始往前推,创建历史数据点(确保时间对齐到30分钟边界)
|
||||
fill_points = []
|
||||
for i in range(needed_points):
|
||||
# 每次往前推30分钟,确保时间对齐
|
||||
point_time = current_interval_start - timedelta(
|
||||
minutes=30 * (i + 1)
|
||||
)
|
||||
point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
|
||||
|
||||
# 确保时间对齐到30分钟边界
|
||||
aligned_minute = (point_time.minute // 30) * 30
|
||||
point_time = point_time.replace(
|
||||
minute=aligned_minute, second=0, microsecond=0
|
||||
)
|
||||
point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
|
||||
|
||||
history_point = {
|
||||
"timestamp": point_time.isoformat(),
|
||||
"online_count": 0,
|
||||
"playing_count": 0
|
||||
"playing_count": 0,
|
||||
}
|
||||
fill_points.append(json.dumps(history_point))
|
||||
|
||||
@@ -238,9 +221,7 @@ class EnhancedIntervalStatsManager:
|
||||
temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
|
||||
if history_length > 0:
|
||||
# 复制现有数据到临时key
|
||||
existing_data = await _redis_exec(
|
||||
redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
|
||||
)
|
||||
existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
|
||||
if existing_data:
|
||||
for data in existing_data:
|
||||
await _redis_exec(redis_sync.rpush, temp_key, data)
|
||||
@@ -250,19 +231,13 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
# 先添加填充数据(最旧的)
|
||||
for point in reversed(fill_points): # 反向添加,最旧的在最后
|
||||
await _redis_exec(
|
||||
redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
|
||||
)
|
||||
await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
|
||||
|
||||
# 再添加原有数据(较新的)
|
||||
if history_length > 0:
|
||||
existing_data = await _redis_exec(
|
||||
redis_sync.lrange, temp_key, 0, -1
|
||||
)
|
||||
existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
|
||||
for data in existing_data:
|
||||
await _redis_exec(
|
||||
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
|
||||
)
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
|
||||
|
||||
# 清理临时key
|
||||
await redis_async.delete(temp_key)
|
||||
@@ -273,9 +248,7 @@ class EnhancedIntervalStatsManager:
|
||||
# 设置过期时间
|
||||
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
|
||||
|
||||
logger.info(
|
||||
f"Filled {len(fill_points)} historical data points with zeros"
|
||||
)
|
||||
logger.info(f"Filled {len(fill_points)} historical data points with zeros")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring 24h history exists: {e}")
|
||||
@@ -287,9 +260,7 @@ class EnhancedIntervalStatsManager:
|
||||
redis_async = get_redis()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
|
||||
# 添加到区间在线用户集合
|
||||
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
|
||||
@@ -298,9 +269,7 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
# 如果用户在游玩,也添加到游玩用户集合
|
||||
if is_playing:
|
||||
playing_key = (
|
||||
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
)
|
||||
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
|
||||
await redis_async.expire(playing_key, 35 * 60)
|
||||
|
||||
@@ -308,7 +277,8 @@ class EnhancedIntervalStatsManager:
|
||||
await EnhancedIntervalStatsManager._update_interval_stats()
|
||||
|
||||
logger.debug(
|
||||
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}"
|
||||
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}"
|
||||
f"-{current_interval.end_time.strftime('%H:%M')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -321,15 +291,11 @@ class EnhancedIntervalStatsManager:
|
||||
redis_async = get_redis()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
|
||||
# 获取区间内独特用户数
|
||||
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
|
||||
playing_key = (
|
||||
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
)
|
||||
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
|
||||
unique_online = await _redis_exec(redis_sync.scard, online_key)
|
||||
unique_playing = await _redis_exec(redis_sync.scard, playing_key)
|
||||
@@ -339,16 +305,12 @@ class EnhancedIntervalStatsManager:
|
||||
current_playing = await _get_playing_users_count(redis_async)
|
||||
|
||||
# 获取现有统计数据
|
||||
existing_data = await _redis_exec(
|
||||
redis_sync.get, current_interval.interval_key
|
||||
)
|
||||
existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
|
||||
if existing_data:
|
||||
stats = IntervalStats.from_dict(json.loads(existing_data))
|
||||
# 更新峰值
|
||||
stats.peak_online_count = max(stats.peak_online_count, current_online)
|
||||
stats.peak_playing_count = max(
|
||||
stats.peak_playing_count, current_playing
|
||||
)
|
||||
stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
|
||||
stats.total_samples += 1
|
||||
else:
|
||||
# 创建新的统计记录
|
||||
@@ -377,7 +339,8 @@ class EnhancedIntervalStatsManager:
|
||||
await redis_async.expire(current_interval.interval_key, 35 * 60)
|
||||
|
||||
logger.debug(
|
||||
f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
|
||||
f"Updated interval stats: online={unique_online}, playing={unique_playing}, "
|
||||
f"peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -395,21 +358,21 @@ class EnhancedIntervalStatsManager:
|
||||
# 上一个区间开始时间是当前区间开始时间减去30分钟
|
||||
previous_start = current_start - timedelta(minutes=30)
|
||||
previous_end = current_start # 上一个区间的结束时间就是当前区间的开始时间
|
||||
|
||||
|
||||
interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start)
|
||||
|
||||
|
||||
previous_interval = IntervalInfo(
|
||||
start_time=previous_start,
|
||||
end_time=previous_end,
|
||||
interval_key=interval_key
|
||||
interval_key=interval_key,
|
||||
)
|
||||
|
||||
# 获取最终统计数据
|
||||
stats_data = await _redis_exec(
|
||||
redis_sync.get, previous_interval.interval_key
|
||||
)
|
||||
stats_data = await _redis_exec(redis_sync.get, previous_interval.interval_key)
|
||||
if not stats_data:
|
||||
logger.warning(f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}")
|
||||
logger.warning(
|
||||
f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}"
|
||||
)
|
||||
return None
|
||||
|
||||
stats = IntervalStats.from_dict(json.loads(stats_data))
|
||||
@@ -418,13 +381,11 @@ class EnhancedIntervalStatsManager:
|
||||
history_point = {
|
||||
"timestamp": previous_interval.start_time.isoformat(),
|
||||
"online_count": stats.unique_online_users,
|
||||
"playing_count": stats.unique_playing_users
|
||||
"playing_count": stats.unique_playing_users,
|
||||
}
|
||||
|
||||
# 添加到历史记录
|
||||
await _redis_exec(
|
||||
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
|
||||
)
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
|
||||
# 只保留48个数据点(24小时,每30分钟一个点)
|
||||
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
|
||||
# 设置过期时间为26小时,确保有足够缓冲
|
||||
@@ -452,12 +413,8 @@ class EnhancedIntervalStatsManager:
|
||||
redis_sync = get_redis_message()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
stats_data = await _redis_exec(
|
||||
redis_sync.get, current_interval.interval_key
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
|
||||
|
||||
if stats_data:
|
||||
return IntervalStats.from_dict(json.loads(stats_data))
|
||||
@@ -506,8 +463,6 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
|
||||
# 便捷函数,用于替换现有的统计更新函数
|
||||
async def update_user_activity_in_interval(
|
||||
user_id: int, is_playing: bool = False
|
||||
) -> None:
|
||||
async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None:
|
||||
"""用户活动时更新区间统计(在登录、开始游玩等时调用)"""
|
||||
await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)
|
||||
|
||||
@@ -11,12 +11,8 @@ def load_achievements() -> Medals:
|
||||
for module in ACHIEVEMENTS_DIR.iterdir():
|
||||
if module.is_file() and module.suffix == ".py":
|
||||
module_name = module.stem
|
||||
module_achievements = importlib.import_module(
|
||||
f"app.achievements.{module_name}"
|
||||
)
|
||||
module_achievements = importlib.import_module(f"app.achievements.{module_name}")
|
||||
medals = getattr(module_achievements, "MEDALS", {})
|
||||
MEDALS.update(medals)
|
||||
logger.success(
|
||||
f"Successfully loaded {len(medals)} achievements from {module_name}.py"
|
||||
)
|
||||
logger.success(f"Successfully loaded {len(medals)} achievements from {module_name}.py")
|
||||
return MEDALS
|
||||
|
||||
@@ -47,6 +47,7 @@ class LoginLogService:
|
||||
|
||||
# 获取并简化User-Agent
|
||||
from app.utils import simplify_user_agent
|
||||
|
||||
raw_user_agent = request.headers.get("User-Agent", "")
|
||||
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
|
||||
|
||||
@@ -67,9 +68,7 @@ class LoginLogService:
|
||||
|
||||
# 在后台线程中运行GeoIP查询(避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
geo_info = await loop.run_in_executor(
|
||||
None, lambda: geoip.lookup(ip_address)
|
||||
)
|
||||
geo_info = await loop.run_in_executor(None, lambda: geoip.lookup(ip_address))
|
||||
|
||||
if geo_info:
|
||||
login_log.country_code = geo_info.get("country_iso", "")
|
||||
@@ -89,10 +88,7 @@ class LoginLogService:
|
||||
|
||||
login_log.organization = geo_info.get("organization", "")
|
||||
|
||||
logger.debug(
|
||||
f"GeoIP lookup for {ip_address}: "
|
||||
f"{geo_info.get('country_name', 'Unknown')}"
|
||||
)
|
||||
logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}")
|
||||
else:
|
||||
logger.warning(f"GeoIP lookup failed for {ip_address}")
|
||||
|
||||
@@ -104,9 +100,7 @@ class LoginLogService:
|
||||
await db.commit()
|
||||
await db.refresh(login_log)
|
||||
|
||||
logger.info(
|
||||
f"Login recorded for user {user_id} from {ip_address} ({login_method})"
|
||||
)
|
||||
logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})")
|
||||
return login_log
|
||||
|
||||
@staticmethod
|
||||
@@ -137,9 +131,7 @@ class LoginLogService:
|
||||
request=request,
|
||||
login_success=False,
|
||||
login_method=login_method,
|
||||
notes=f"Failed login attempt: {attempted_username}"
|
||||
if attempted_username
|
||||
else "Failed login attempt",
|
||||
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import uuid
|
||||
from app.database.chat import ChatMessage, MessageType
|
||||
from app.dependencies.database import get_redis, with_db
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
@@ -34,7 +35,7 @@ class MessageQueue:
|
||||
"""启动消息处理任务"""
|
||||
if not self._processing:
|
||||
self._processing = True
|
||||
asyncio.create_task(self._process_message_queue())
|
||||
bg_tasks.add_task(self._process_message_queue)
|
||||
logger.info("Message queue processing started")
|
||||
|
||||
async def stop_processing(self):
|
||||
@@ -59,12 +60,8 @@ class MessageQueue:
|
||||
message_data["status"] = "pending" # pending, processing, completed, failed
|
||||
|
||||
# 将消息存储到 Redis
|
||||
await self._run_in_executor(
|
||||
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
|
||||
)
|
||||
await self._run_in_executor(
|
||||
self.redis.expire, f"msg:{temp_uuid}", 3600
|
||||
) # 1小时过期
|
||||
await self._run_in_executor(lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data))
|
||||
await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
|
||||
|
||||
# 加入处理队列
|
||||
await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
|
||||
@@ -74,17 +71,13 @@ class MessageQueue:
|
||||
|
||||
async def get_message_status(self, temp_uuid: str) -> dict | None:
|
||||
"""获取消息状态"""
|
||||
message_data = await self._run_in_executor(
|
||||
self.redis.hgetall, f"msg:{temp_uuid}"
|
||||
)
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
||||
if not message_data:
|
||||
return None
|
||||
|
||||
return message_data
|
||||
|
||||
async def get_cached_messages(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict]:
|
||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""
|
||||
从 Redis 获取缓存的消息
|
||||
|
||||
@@ -97,15 +90,11 @@ class MessageQueue:
|
||||
消息列表
|
||||
"""
|
||||
# 从 Redis 获取频道最近的消息 UUID 列表
|
||||
message_uuids = await self._run_in_executor(
|
||||
self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
|
||||
)
|
||||
message_uuids = await self._run_in_executor(self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1)
|
||||
|
||||
messages = []
|
||||
for uuid_str in message_uuids:
|
||||
message_data = await self._run_in_executor(
|
||||
self.redis.hgetall, f"msg:{uuid_str}"
|
||||
)
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
|
||||
if message_data:
|
||||
# 检查是否满足 since 条件
|
||||
if since > 0 and "message_id" in message_data:
|
||||
@@ -116,22 +105,14 @@ class MessageQueue:
|
||||
|
||||
return messages[::-1] # 返回时间顺序
|
||||
|
||||
async def cache_channel_message(
|
||||
self, channel_id: int, temp_uuid: str, max_cache: int = 100
|
||||
):
|
||||
async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
|
||||
"""将消息 UUID 缓存到频道消息列表"""
|
||||
# 添加到频道消息列表开头
|
||||
await self._run_in_executor(
|
||||
self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid
|
||||
)
|
||||
await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
|
||||
# 限制缓存大小
|
||||
await self._run_in_executor(
|
||||
self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1
|
||||
)
|
||||
await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
|
||||
# 设置过期时间(24小时)
|
||||
await self._run_in_executor(
|
||||
self.redis.expire, f"channel:{channel_id}:messages", 86400
|
||||
)
|
||||
await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
|
||||
|
||||
async def _process_message_queue(self):
|
||||
"""异步处理消息队列,批量写入数据库"""
|
||||
@@ -140,9 +121,7 @@ class MessageQueue:
|
||||
# 批量获取消息
|
||||
message_uuids = []
|
||||
for _ in range(self._batch_size):
|
||||
result = await self._run_in_executor(
|
||||
lambda: self.redis.brpop(["message_queue"], timeout=1)
|
||||
)
|
||||
result = await self._run_in_executor(lambda: self.redis.brpop(["message_queue"], timeout=1))
|
||||
if result:
|
||||
message_uuids.append(result[1])
|
||||
else:
|
||||
@@ -166,16 +145,12 @@ class MessageQueue:
|
||||
for temp_uuid in message_uuids:
|
||||
try:
|
||||
# 获取消息数据
|
||||
message_data = await self._run_in_executor(
|
||||
self.redis.hgetall, f"msg:{temp_uuid}"
|
||||
)
|
||||
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
|
||||
if not message_data:
|
||||
continue
|
||||
|
||||
# 更新状态为处理中
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"msg:{temp_uuid}", "status", "processing"
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
|
||||
|
||||
# 创建数据库消息对象
|
||||
msg = ChatMessage(
|
||||
@@ -190,9 +165,7 @@ class MessageQueue:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing message {temp_uuid}: {e}")
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
||||
|
||||
if messages_to_insert:
|
||||
try:
|
||||
@@ -211,16 +184,12 @@ class MessageQueue:
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"message_id": str(msg.message_id),
|
||||
"created_at": msg.timestamp.isoformat()
|
||||
if msg.timestamp
|
||||
else "",
|
||||
"created_at": msg.timestamp.isoformat() if msg.timestamp else "",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}"
|
||||
)
|
||||
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting messages to database: {e}")
|
||||
@@ -228,9 +197,7 @@ class MessageQueue:
|
||||
|
||||
# 标记所有消息为失败
|
||||
for _, temp_uuid in messages_to_insert:
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
|
||||
|
||||
|
||||
# 全局消息队列实例
|
||||
|
||||
@@ -33,36 +33,22 @@ class MessageQueueProcessor:
|
||||
"""将消息缓存到 Redis"""
|
||||
try:
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data
|
||||
)
|
||||
await self._redis_exec(
|
||||
self.redis_message.expire, f"msg:{temp_uuid}", 3600
|
||||
) # 1小时过期
|
||||
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data)
|
||||
await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
|
||||
|
||||
# 加入频道消息列表
|
||||
await self._redis_exec(
|
||||
self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid
|
||||
)
|
||||
await self._redis_exec(
|
||||
self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99
|
||||
) # 保持最新100条
|
||||
await self._redis_exec(
|
||||
self.redis_message.expire, f"channel:{channel_id}:messages", 86400
|
||||
) # 24小时过期
|
||||
await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid)
|
||||
await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条
|
||||
await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期
|
||||
|
||||
# 加入异步处理队列
|
||||
await self._redis_exec(
|
||||
self.redis_message.lpush, "message_write_queue", temp_uuid
|
||||
)
|
||||
await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid)
|
||||
|
||||
logger.info(f"Message cached to Redis: {temp_uuid}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cache message to Redis: {e}")
|
||||
|
||||
async def get_cached_messages(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict]:
|
||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""从 Redis 获取缓存的消息"""
|
||||
try:
|
||||
message_uuids = await self._redis_exec(
|
||||
@@ -78,15 +64,11 @@ class MessageQueueProcessor:
|
||||
if isinstance(temp_uuid, bytes):
|
||||
temp_uuid = temp_uuid.decode("utf-8")
|
||||
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis_message.hgetall, f"msg:{temp_uuid}"
|
||||
)
|
||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
||||
if raw_data:
|
||||
# 解码 Redis 返回的字节数据
|
||||
message_data = {
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
|
||||
"utf-8"
|
||||
)
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
|
||||
if isinstance(v, bytes)
|
||||
else v
|
||||
for k, v in raw_data.items()
|
||||
@@ -103,9 +85,7 @@ class MessageQueueProcessor:
|
||||
logger.error(f"Failed to get cached messages: {e}")
|
||||
return []
|
||||
|
||||
async def update_message_status(
|
||||
self, temp_uuid: str, status: str, message_id: int | None = None
|
||||
):
|
||||
async def update_message_status(self, temp_uuid: str, status: str, message_id: int | None = None):
|
||||
"""更新消息状态"""
|
||||
try:
|
||||
update_data = {"status": status}
|
||||
@@ -113,26 +93,20 @@ class MessageQueueProcessor:
|
||||
update_data["message_id"] = str(message_id)
|
||||
update_data["db_timestamp"] = datetime.now().isoformat()
|
||||
|
||||
await self._redis_exec(
|
||||
self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data
|
||||
)
|
||||
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update message status: {e}")
|
||||
|
||||
async def get_message_status(self, temp_uuid: str) -> dict | None:
|
||||
"""获取消息状态"""
|
||||
try:
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis_message.hgetall, f"msg:{temp_uuid}"
|
||||
)
|
||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
||||
if not raw_data:
|
||||
return None
|
||||
|
||||
# 解码 Redis 返回的字节数据
|
||||
return {
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
|
||||
if isinstance(v, bytes)
|
||||
else v
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
|
||||
for k, v in raw_data.items()
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -148,9 +122,7 @@ class MessageQueueProcessor:
|
||||
# 批量获取消息
|
||||
message_uuids = []
|
||||
for _ in range(20): # 批量处理20条消息
|
||||
result = await self._redis_exec(
|
||||
self.redis_message.brpop, ["message_write_queue"], timeout=1
|
||||
)
|
||||
result = await self._redis_exec(self.redis_message.brpop, ["message_write_queue"], timeout=1)
|
||||
if result:
|
||||
# result是 (queue_name, value) 的元组,需要解码
|
||||
uuid_value = result[1]
|
||||
@@ -179,17 +151,13 @@ class MessageQueueProcessor:
|
||||
for temp_uuid in message_uuids:
|
||||
try:
|
||||
# 获取消息数据并解码
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis_message.hgetall, f"msg:{temp_uuid}"
|
||||
)
|
||||
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
|
||||
if not raw_data:
|
||||
continue
|
||||
|
||||
# 解码 Redis 返回的字节数据
|
||||
message_data = {
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
|
||||
"utf-8"
|
||||
)
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
|
||||
if isinstance(v, bytes)
|
||||
else v
|
||||
for k, v in raw_data.items()
|
||||
@@ -215,10 +183,7 @@ class MessageQueueProcessor:
|
||||
await session.refresh(msg)
|
||||
|
||||
# 更新成功状态,包含临时消息ID映射
|
||||
assert msg.message_id is not None
|
||||
await self.update_message_status(
|
||||
temp_uuid, "completed", msg.message_id
|
||||
)
|
||||
await self.update_message_status(temp_uuid, "completed", msg.message_id)
|
||||
|
||||
# 如果有临时消息ID,存储映射关系并通知客户端更新
|
||||
if message_data.get("temp_message_id"):
|
||||
@@ -232,12 +197,11 @@ class MessageQueueProcessor:
|
||||
|
||||
# 发送消息ID更新通知到频道
|
||||
channel_id = int(message_data["channel_id"])
|
||||
await self._notify_message_update(
|
||||
channel_id, temp_msg_id, msg.message_id, message_data
|
||||
)
|
||||
await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data)
|
||||
|
||||
logger.info(
|
||||
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}"
|
||||
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, "
|
||||
f"temp_id: {message_data.get('temp_message_id')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -272,9 +236,7 @@ class MessageQueueProcessor:
|
||||
json.dumps(update_event),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}"
|
||||
)
|
||||
logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to notify message update: {e}")
|
||||
@@ -320,9 +282,7 @@ async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid:
|
||||
await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)
|
||||
|
||||
|
||||
async def get_cached_messages(
|
||||
channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict]:
|
||||
async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""从 Redis 获取缓存的消息 - 便捷接口"""
|
||||
return await message_queue_processor.get_cached_messages(channel_id, limit, since)
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
此模块提供在游玩状态下维护用户在线状态的功能,
|
||||
解决游玩时显示离线的问题。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.log import logger
|
||||
@@ -17,32 +17,32 @@ from app.router.v2.stats import REDIS_PLAYING_USERS_KEY, _redis_exec, get_redis_
|
||||
async def maintain_playing_users_online_status():
|
||||
"""
|
||||
维护正在游玩用户的在线状态
|
||||
|
||||
|
||||
定期刷新正在游玩用户的metadata在线标记,
|
||||
确保他们在游玩过程中显示为在线状态。
|
||||
"""
|
||||
redis_sync = get_redis_message()
|
||||
redis_async = get_redis()
|
||||
|
||||
|
||||
try:
|
||||
# 获取所有正在游玩的用户
|
||||
playing_users = await _redis_exec(redis_sync.smembers, REDIS_PLAYING_USERS_KEY)
|
||||
|
||||
|
||||
if not playing_users:
|
||||
return
|
||||
|
||||
|
||||
logger.debug(f"Maintaining online status for {len(playing_users)} playing users")
|
||||
|
||||
|
||||
# 为每个游玩用户刷新metadata在线标记
|
||||
for user_id in playing_users:
|
||||
user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
|
||||
metadata_key = f"metadata:online:{user_id_str}"
|
||||
|
||||
|
||||
# 设置或刷新metadata在线标记,过期时间为1小时
|
||||
await redis_async.set(metadata_key, "playing", ex=3600)
|
||||
|
||||
|
||||
logger.debug(f"Updated metadata online status for {len(playing_users)} playing users")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error maintaining playing users online status: {e}")
|
||||
|
||||
@@ -50,11 +50,11 @@ async def maintain_playing_users_online_status():
|
||||
async def start_online_status_maintenance_task():
|
||||
"""
|
||||
启动在线状态维护任务
|
||||
|
||||
|
||||
每5分钟运行一次维护任务,确保游玩用户保持在线状态
|
||||
"""
|
||||
logger.info("Starting online status maintenance task")
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
await maintain_playing_users_online_status()
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
此模块负责统一管理用户的在线状态,确保用户在连接WebSocket后立即显示为在线。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
@@ -15,92 +15,93 @@ from app.router.v2.stats import add_online_user
|
||||
|
||||
class OnlineStatusManager:
|
||||
"""在线状态管理器"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def set_user_online(user_id: int, hub_type: str = "general") -> None:
|
||||
"""
|
||||
设置用户为在线状态
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
hub_type: Hub类型 (metadata, spectator, multiplayer等)
|
||||
"""
|
||||
try:
|
||||
redis = get_redis()
|
||||
|
||||
|
||||
# 1. 添加到在线用户集合
|
||||
await add_online_user(user_id)
|
||||
|
||||
|
||||
# 2. 设置metadata在线标记,这是is_online检查的关键
|
||||
metadata_key = f"metadata:online:{user_id}"
|
||||
await redis.set(metadata_key, hub_type, ex=7200) # 2小时过期
|
||||
|
||||
|
||||
# 3. 设置最后活跃时间戳
|
||||
last_seen_key = f"user:last_seen:{user_id}"
|
||||
await redis.set(last_seen_key, int(datetime.utcnow().timestamp()), ex=7200)
|
||||
|
||||
|
||||
logger.debug(f"[OnlineStatusManager] User {user_id} set online via {hub_type}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OnlineStatusManager] Error setting user {user_id} online: {e}")
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def refresh_user_online_status(user_id: int, hub_type: str = "active") -> None:
|
||||
"""
|
||||
刷新用户的在线状态
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
hub_type: 当前活动类型
|
||||
"""
|
||||
try:
|
||||
redis = get_redis()
|
||||
|
||||
|
||||
# 刷新metadata在线标记
|
||||
metadata_key = f"metadata:online:{user_id}"
|
||||
await redis.set(metadata_key, hub_type, ex=7200)
|
||||
|
||||
|
||||
# 刷新最后活跃时间
|
||||
last_seen_key = f"user:last_seen:{user_id}"
|
||||
await redis.set(last_seen_key, int(datetime.utcnow().timestamp()), ex=7200)
|
||||
|
||||
|
||||
logger.debug(f"[OnlineStatusManager] Refreshed online status for user {user_id}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OnlineStatusManager] Error refreshing user {user_id} status: {e}")
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def set_user_offline(user_id: int) -> None:
|
||||
"""
|
||||
设置用户为离线状态
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
try:
|
||||
redis = get_redis()
|
||||
|
||||
|
||||
# 删除metadata在线标记
|
||||
metadata_key = f"metadata:online:{user_id}"
|
||||
await redis.delete(metadata_key)
|
||||
|
||||
|
||||
# 从在线用户集合中移除
|
||||
from app.router.v2.stats import remove_online_user
|
||||
|
||||
await remove_online_user(user_id)
|
||||
|
||||
|
||||
logger.debug(f"[OnlineStatusManager] User {user_id} set offline")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OnlineStatusManager] Error setting user {user_id} offline: {e}")
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def is_user_online(user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否在线
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 用户是否在线
|
||||
"""
|
||||
@@ -112,19 +113,19 @@ class OnlineStatusManager:
|
||||
except Exception as e:
|
||||
logger.error(f"[OnlineStatusManager] Error checking user {user_id} online status: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_online_users_count() -> int:
|
||||
"""
|
||||
获取在线用户数量
|
||||
|
||||
|
||||
Returns:
|
||||
int: 在线用户数量
|
||||
"""
|
||||
try:
|
||||
from app.router.v2.stats import _get_online_users_count
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
from app.router.v2.stats import _get_online_users_count
|
||||
|
||||
redis = get_redis()
|
||||
return await _get_online_users_count(redis)
|
||||
except Exception as e:
|
||||
|
||||
@@ -50,7 +50,6 @@ class OptimizedMessageService:
|
||||
Returns:
|
||||
消息响应对象
|
||||
"""
|
||||
assert sender.id is not None
|
||||
|
||||
# 准备消息数据
|
||||
message_data = {
|
||||
@@ -97,9 +96,7 @@ class OptimizedMessageService:
|
||||
logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
|
||||
return temp_response
|
||||
|
||||
async def get_cached_messages(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict]:
|
||||
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
|
||||
"""
|
||||
获取缓存的消息
|
||||
|
||||
@@ -125,9 +122,7 @@ class OptimizedMessageService:
|
||||
"""
|
||||
return await self.message_queue.get_message_status(temp_uuid)
|
||||
|
||||
async def wait_for_message_persisted(
|
||||
self, temp_uuid: str, timeout: int = 30
|
||||
) -> dict | None:
|
||||
async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> dict | None: # noqa: ASYNC109
|
||||
"""
|
||||
等待消息持久化到数据库
|
||||
|
||||
|
||||
@@ -4,74 +4,67 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional, Tuple
|
||||
import json
|
||||
|
||||
from app.config import settings
|
||||
from app.auth import get_password_hash, invalidate_user_tokens
|
||||
from app.database import User
|
||||
from app.dependencies.database import with_db
|
||||
from app.service.email_service import EmailService
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.log import logger
|
||||
from app.auth import get_password_hash, invalidate_user_tokens
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.service.email_service import EmailService
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
class PasswordResetService:
|
||||
"""密码重置服务 - 使用Redis管理验证码"""
|
||||
|
||||
|
||||
# Redis键前缀
|
||||
RESET_CODE_PREFIX = "password_reset:code:" # 存储验证码
|
||||
RESET_RATE_LIMIT_PREFIX = "password_reset:rate_limit:" # 限制请求频率
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.email_service = EmailService()
|
||||
|
||||
|
||||
def generate_reset_code(self) -> str:
|
||||
"""生成8位重置验证码"""
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
return "".join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
def _get_reset_code_key(self, email: str) -> str:
|
||||
"""获取验证码Redis键"""
|
||||
return f"{self.RESET_CODE_PREFIX}{email.lower()}"
|
||||
|
||||
|
||||
def _get_rate_limit_key(self, email: str) -> str:
|
||||
"""获取频率限制Redis键"""
|
||||
return f"{self.RESET_RATE_LIMIT_PREFIX}{email.lower()}"
|
||||
|
||||
|
||||
async def request_password_reset(
|
||||
self,
|
||||
email: str,
|
||||
ip_address: str,
|
||||
user_agent: str,
|
||||
redis: Redis
|
||||
) -> Tuple[bool, str]:
|
||||
self, email: str, ip_address: str, user_agent: str, redis: Redis
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
请求密码重置
|
||||
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
ip_address: 请求IP
|
||||
user_agent: 用户代理
|
||||
redis: Redis连接
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[success, message]
|
||||
"""
|
||||
email = email.lower().strip()
|
||||
|
||||
|
||||
async with with_db() as session:
|
||||
# 查找用户
|
||||
user_query = select(User).where(User.email == email)
|
||||
user_result = await session.exec(user_query)
|
||||
user = user_result.first()
|
||||
|
||||
|
||||
if not user:
|
||||
# 为了安全考虑,不告诉用户邮箱不存在,但仍然要检查频率限制
|
||||
rate_limit_key = self._get_rate_limit_key(email)
|
||||
@@ -80,15 +73,15 @@ class PasswordResetService:
|
||||
# 设置一个假的频率限制,防止恶意用户探测邮箱
|
||||
await redis.setex(rate_limit_key, 60, "1")
|
||||
return True, "如果该邮箱地址存在,您将收到密码重置邮件"
|
||||
|
||||
|
||||
# 检查频率限制
|
||||
rate_limit_key = self._get_rate_limit_key(email)
|
||||
if await redis.get(rate_limit_key):
|
||||
return False, "请求过于频繁,请稍后再试"
|
||||
|
||||
|
||||
# 生成重置验证码
|
||||
reset_code = self.generate_reset_code()
|
||||
|
||||
|
||||
# 存储验证码信息到Redis
|
||||
reset_code_key = self._get_reset_code_key(email)
|
||||
reset_data = {
|
||||
@@ -98,22 +91,18 @@ class PasswordResetService:
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"ip_address": ip_address,
|
||||
"user_agent": user_agent,
|
||||
"used": False
|
||||
"used": False,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 先设置频率限制
|
||||
await redis.setex(rate_limit_key, 60, "1")
|
||||
# 存储验证码,10分钟过期
|
||||
await redis.setex(reset_code_key, 600, json.dumps(reset_data))
|
||||
|
||||
|
||||
# 发送重置邮件
|
||||
email_sent = await self.send_password_reset_email(
|
||||
email=email,
|
||||
code=reset_code,
|
||||
username=user.username
|
||||
)
|
||||
|
||||
email_sent = await self.send_password_reset_email(email=email, code=reset_code, username=user.username)
|
||||
|
||||
if email_sent:
|
||||
logger.info(f"[Password Reset] Sent reset code to user {user.id} ({email})")
|
||||
return True, "密码重置邮件已发送,请查收邮箱"
|
||||
@@ -123,17 +112,17 @@ class PasswordResetService:
|
||||
await redis.delete(rate_limit_key)
|
||||
logger.warning(f"[Password Reset] Email sending failed, cleaned up Redis data for {email}")
|
||||
return False, "邮件发送失败,请稍后重试"
|
||||
|
||||
except Exception as e:
|
||||
|
||||
except Exception:
|
||||
# Redis操作失败,清理可能的部分数据
|
||||
try:
|
||||
await redis.delete(reset_code_key)
|
||||
await redis.delete(rate_limit_key)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"[Password Reset] Redis operation failed: {e}")
|
||||
logger.exception("[Password Reset] Redis operation failed")
|
||||
return False, "服务暂时不可用,请稍后重试"
|
||||
|
||||
|
||||
async def send_password_reset_email(self, email: str, code: str, username: str) -> bool:
|
||||
"""发送密码重置邮件(使用邮件队列)"""
|
||||
try:
|
||||
@@ -206,15 +195,15 @@ class PasswordResetService:
|
||||
<h1>osu! 密码重置</h1>
|
||||
<p>Password Reset Request</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="content">
|
||||
<h2>你好 {username}!</h2>
|
||||
<p>我们收到了您的密码重置请求。如果这是您本人操作,请使用以下验证码重置密码:</p>
|
||||
|
||||
|
||||
<div class="code">{code}</div>
|
||||
|
||||
|
||||
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
|
||||
|
||||
|
||||
<div class="danger">
|
||||
<strong>⚠️ 安全提醒:</strong>
|
||||
<ul>
|
||||
@@ -224,19 +213,19 @@ class PasswordResetService:
|
||||
<li>建议设置一个强密码以保护您的账户安全</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
<p>如果您有任何问题,请联系我们的支持团队。</p>
|
||||
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||
|
||||
|
||||
<h3>Hello {username}!</h3>
|
||||
<p>We received a request to reset your password. If this was you, please use the following verification code to reset your password:</p>
|
||||
|
||||
|
||||
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
|
||||
|
||||
|
||||
<p><strong>Security Notice:</strong> Do not share this verification code with anyone. If you did not request a password reset, please ignore this email.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="footer">
|
||||
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>This email was sent automatically, please do not reply.</p>
|
||||
@@ -244,8 +233,8 @@ class PasswordResetService:
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
# 纯文本内容(作为备用)
|
||||
plain_content = f"""
|
||||
你好 {username}!
|
||||
@@ -270,120 +259,123 @@ class PasswordResetService:
|
||||
# 添加邮件到队列
|
||||
subject = "密码重置 - Password Reset"
|
||||
metadata = {"type": "password_reset", "email": email, "code": code}
|
||||
|
||||
|
||||
await email_queue.enqueue_email(
|
||||
to_email=email,
|
||||
subject=subject,
|
||||
content=plain_content,
|
||||
html_content=html_content,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Password Reset] Enqueued reset code email to {email}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Password Reset] Failed to enqueue email: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def reset_password(
|
||||
self,
|
||||
email: str,
|
||||
reset_code: str,
|
||||
new_password: str,
|
||||
ip_address: str,
|
||||
redis: Redis
|
||||
) -> Tuple[bool, str]:
|
||||
redis: Redis,
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
重置密码
|
||||
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
reset_code: 重置验证码
|
||||
new_password: 新密码
|
||||
ip_address: 请求IP
|
||||
redis: Redis连接
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[success, message]
|
||||
"""
|
||||
email = email.lower().strip()
|
||||
reset_code = reset_code.strip()
|
||||
|
||||
|
||||
async with with_db() as session:
|
||||
# 从Redis获取验证码数据
|
||||
reset_code_key = self._get_reset_code_key(email)
|
||||
reset_data_str = await redis.get(reset_code_key)
|
||||
|
||||
|
||||
if not reset_data_str:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
|
||||
try:
|
||||
reset_data = json.loads(reset_data_str)
|
||||
except json.JSONDecodeError:
|
||||
return False, "验证码数据格式错误"
|
||||
|
||||
|
||||
# 验证验证码
|
||||
if reset_data.get("reset_code") != reset_code:
|
||||
return False, "验证码错误"
|
||||
|
||||
|
||||
# 检查是否已使用
|
||||
if reset_data.get("used", False):
|
||||
return False, "验证码已使用"
|
||||
|
||||
|
||||
# 验证邮箱匹配
|
||||
if reset_data.get("email") != email:
|
||||
return False, "邮箱地址不匹配"
|
||||
|
||||
|
||||
# 查找用户
|
||||
user_query = select(User).where(User.email == email)
|
||||
user_result = await session.exec(user_query)
|
||||
user = user_result.first()
|
||||
|
||||
|
||||
if not user:
|
||||
return False, "用户不存在"
|
||||
|
||||
|
||||
if user.id is None:
|
||||
return False, "用户ID无效"
|
||||
|
||||
|
||||
# 验证用户ID匹配
|
||||
if reset_data.get("user_id") != user.id:
|
||||
return False, "用户信息不匹配"
|
||||
|
||||
|
||||
# 密码强度检查
|
||||
if len(new_password) < 6:
|
||||
return False, "密码长度至少为6位"
|
||||
|
||||
|
||||
try:
|
||||
# 先标记验证码为已使用(在数据库操作之前)
|
||||
reset_data["used"] = True
|
||||
reset_data["used_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
# 保存用户ID用于日志记录
|
||||
user_id = user.id
|
||||
|
||||
|
||||
# 更新用户密码
|
||||
password_hash = get_password_hash(new_password)
|
||||
user.pw_bcrypt = password_hash # 使用正确的字段名称 pw_bcrypt 而不是 password_hash
|
||||
|
||||
|
||||
# 提交数据库更改
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 使该用户的所有现有令牌失效(使其他客户端登录失效)
|
||||
tokens_deleted = await invalidate_user_tokens(session, user_id)
|
||||
|
||||
|
||||
# 数据库操作成功后,更新Redis状态
|
||||
await redis.setex(reset_code_key, 300, json.dumps(reset_data)) # 保留5分钟用于日志记录
|
||||
|
||||
logger.info(f"[Password Reset] User {user_id} ({email}) successfully reset password from IP {ip_address}, invalidated {tokens_deleted} tokens")
|
||||
|
||||
logger.info(
|
||||
f"[Password Reset] User {user_id} ({email}) successfully reset password from IP {ip_address},"
|
||||
f" invalidated {tokens_deleted} tokens"
|
||||
)
|
||||
return True, "密码重置成功,所有设备已被登出"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 不要在异常处理中访问user.id,可能触发数据库操作
|
||||
user_id = reset_data.get("user_id", "未知")
|
||||
logger.error(f"[Password Reset] Failed to reset password for user {user_id}: {e}")
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# 数据库回滚时,需要恢复Redis中的验证码状态
|
||||
try:
|
||||
# 恢复验证码为未使用状态
|
||||
@@ -394,35 +386,39 @@ class PasswordResetService:
|
||||
"created_at": reset_data.get("created_at"),
|
||||
"ip_address": reset_data.get("ip_address"),
|
||||
"user_agent": reset_data.get("user_agent"),
|
||||
"used": False # 恢复为未使用状态
|
||||
"used": False, # 恢复为未使用状态
|
||||
}
|
||||
|
||||
|
||||
# 计算剩余的TTL时间
|
||||
created_at = datetime.fromisoformat(reset_data.get("created_at", ""))
|
||||
elapsed = (datetime.now(UTC) - created_at).total_seconds()
|
||||
remaining_ttl = max(0, 600 - int(elapsed)) # 600秒总过期时间
|
||||
|
||||
|
||||
if remaining_ttl > 0:
|
||||
await redis.setex(reset_code_key, remaining_ttl, json.dumps(original_reset_data))
|
||||
await redis.setex(
|
||||
reset_code_key,
|
||||
remaining_ttl,
|
||||
json.dumps(original_reset_data),
|
||||
)
|
||||
logger.info(f"[Password Reset] Restored Redis state after database rollback for {email}")
|
||||
else:
|
||||
# 如果已经过期,直接删除
|
||||
await redis.delete(reset_code_key)
|
||||
logger.info(f"[Password Reset] Removed expired reset code after database rollback for {email}")
|
||||
|
||||
|
||||
except Exception as redis_error:
|
||||
logger.error(f"[Password Reset] Failed to restore Redis state after rollback: {redis_error}")
|
||||
|
||||
|
||||
return False, "密码重置失败,请稍后重试"
|
||||
|
||||
|
||||
async def get_reset_attempts_count(self, email: str, redis: Redis) -> int:
|
||||
"""
|
||||
获取邮箱的重置尝试次数(通过检查频率限制键)
|
||||
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
redis: Redis连接
|
||||
|
||||
|
||||
Returns:
|
||||
尝试次数
|
||||
"""
|
||||
|
||||
@@ -34,9 +34,7 @@ class DateTimeEncoder(json.JSONEncoder):
|
||||
|
||||
def safe_json_dumps(data) -> str:
|
||||
"""安全的 JSON 序列化,支持 datetime 对象"""
|
||||
return json.dumps(
|
||||
data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
|
||||
)
|
||||
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
class RankingCacheService:
|
||||
@@ -225,9 +223,7 @@ class RankingCacheService:
|
||||
) -> None:
|
||||
"""刷新排行榜缓存"""
|
||||
if self._refreshing:
|
||||
logger.debug(
|
||||
f"Ranking cache refresh already in progress for {ruleset}:{type}"
|
||||
)
|
||||
logger.debug(f"Ranking cache refresh already in progress for {ruleset}:{type}")
|
||||
return
|
||||
|
||||
# 使用配置文件的设置
|
||||
@@ -253,9 +249,7 @@ class RankingCacheService:
|
||||
order_by = col(UserStatistics.ranked_score).desc()
|
||||
|
||||
if country:
|
||||
wheres.append(
|
||||
col(UserStatistics.user).has(country_code=country.upper())
|
||||
)
|
||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||
|
||||
# 获取总用户数用于统计
|
||||
total_users_query = select(UserStatistics).where(*wheres)
|
||||
@@ -277,11 +271,7 @@ class RankingCacheService:
|
||||
for page in range(1, max_pages + 1):
|
||||
try:
|
||||
statistics_list = await session.exec(
|
||||
select(UserStatistics)
|
||||
.where(*wheres)
|
||||
.order_by(order_by)
|
||||
.limit(50)
|
||||
.offset(50 * (page - 1))
|
||||
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
|
||||
)
|
||||
|
||||
statistics_data = statistics_list.all()
|
||||
@@ -291,9 +281,7 @@ class RankingCacheService:
|
||||
# 转换为响应格式并确保正确序列化
|
||||
ranking_data = []
|
||||
for statistics in statistics_data:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(
|
||||
statistics, session, None, include
|
||||
)
|
||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
||||
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
|
||||
user_dict = json.loads(user_stats_resp.model_dump_json())
|
||||
ranking_data.append(user_dict)
|
||||
@@ -323,9 +311,7 @@ class RankingCacheService:
|
||||
) -> None:
|
||||
"""刷新地区排行榜缓存"""
|
||||
if self._refreshing:
|
||||
logger.debug(
|
||||
f"Country ranking cache refresh already in progress for {ruleset}"
|
||||
)
|
||||
logger.debug(f"Country ranking cache refresh already in progress for {ruleset}")
|
||||
return
|
||||
|
||||
if max_pages is None:
|
||||
@@ -449,9 +435,7 @@ class RankingCacheService:
|
||||
for country in top_countries:
|
||||
for mode in game_modes:
|
||||
for ranking_type in ranking_types:
|
||||
task = self.refresh_ranking_cache(
|
||||
session, mode, ranking_type, country
|
||||
)
|
||||
task = self.refresh_ranking_cache(session, mode, ranking_type, country)
|
||||
refresh_tasks.append(task)
|
||||
|
||||
# 地区排行榜
|
||||
@@ -493,9 +477,7 @@ class RankingCacheService:
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
deleted_keys += len(keys)
|
||||
logger.info(
|
||||
f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
|
||||
)
|
||||
logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
|
||||
elif ruleset:
|
||||
# 删除特定游戏模式的所有缓存
|
||||
patterns = [
|
||||
@@ -563,9 +545,7 @@ class RankingCacheService:
|
||||
"cached_user_rankings": len(ranking_keys),
|
||||
"cached_country_rankings": len(country_keys),
|
||||
"total_cached_rankings": len(total_keys),
|
||||
"estimated_total_size_mb": (
|
||||
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||
),
|
||||
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
|
||||
"refreshing": self._refreshing,
|
||||
}
|
||||
except Exception as e:
|
||||
|
||||
@@ -35,12 +35,8 @@ async def recalculate():
|
||||
fetcher = await get_fetcher()
|
||||
redis = get_redis()
|
||||
for mode in GameMode:
|
||||
await session.execute(
|
||||
delete(PPBestScore).where(col(PPBestScore.gamemode) == mode)
|
||||
)
|
||||
await session.execute(
|
||||
delete(BestScore).where(col(BestScore.gamemode) == mode)
|
||||
)
|
||||
await session.execute(delete(PPBestScore).where(col(PPBestScore.gamemode) == mode))
|
||||
await session.execute(delete(BestScore).where(col(BestScore.gamemode) == mode))
|
||||
await session.commit()
|
||||
logger.info(f"Recalculating for mode: {mode}")
|
||||
statistics_list = (
|
||||
@@ -53,32 +49,21 @@ async def recalculate():
|
||||
).all()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_recalculate_pp(
|
||||
statistics.user_id, statistics.mode, session, fetcher, redis
|
||||
)
|
||||
_recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis)
|
||||
for statistics in statistics_list
|
||||
]
|
||||
)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_recalculate_best_score(
|
||||
statistics.user_id, statistics.mode, session
|
||||
)
|
||||
_recalculate_best_score(statistics.user_id, statistics.mode, session)
|
||||
for statistics in statistics_list
|
||||
]
|
||||
)
|
||||
await session.commit()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_recalculate_statistics(statistics, session)
|
||||
for statistics in statistics_list
|
||||
]
|
||||
)
|
||||
await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list])
|
||||
|
||||
await session.commit()
|
||||
logger.success(
|
||||
f"Recalculated for mode: {mode}, total users: {len(statistics_list)}"
|
||||
)
|
||||
logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}")
|
||||
|
||||
|
||||
async def _recalculate_pp(
|
||||
@@ -104,9 +89,7 @@ async def _recalculate_pp(
|
||||
beatmap_id = score.beatmap_id
|
||||
while time > 0:
|
||||
try:
|
||||
db_beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=beatmap_id
|
||||
)
|
||||
db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id)
|
||||
except HTTPError:
|
||||
time -= 1
|
||||
await asyncio.sleep(2)
|
||||
@@ -116,9 +99,7 @@ async def _recalculate_pp(
|
||||
score.pp = 0
|
||||
return
|
||||
try:
|
||||
pp = await pre_fetch_and_calculate_pp(
|
||||
score, beatmap_id, session, redis, fetcher
|
||||
)
|
||||
pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
|
||||
score.pp = pp
|
||||
if pp == 0:
|
||||
return
|
||||
@@ -138,15 +119,10 @@ async def _recalculate_pp(
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Error calculating pp for score {score.id} on beatmap {beatmap_id}"
|
||||
)
|
||||
logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}")
|
||||
return
|
||||
if time <= 0:
|
||||
logger.warning(
|
||||
f"Failed to fetch beatmap {beatmap_id} after 10 attempts, "
|
||||
"retrying later..."
|
||||
)
|
||||
logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...")
|
||||
return score
|
||||
|
||||
while len(scores) > 0:
|
||||
@@ -271,9 +247,7 @@ async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSess
|
||||
statistics.count_100 += score.n100 + score.nkatu
|
||||
statistics.count_50 += score.n50
|
||||
statistics.count_miss += score.nmiss
|
||||
statistics.total_hits += (
|
||||
score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50
|
||||
)
|
||||
statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50
|
||||
|
||||
if ranked and score.passed:
|
||||
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.database.chat import ChatMessage, ChatMessageResp, MessageType
|
||||
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
|
||||
from app.dependencies.database import get_redis_message, with_db
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
|
||||
|
||||
class RedisMessageSystem:
|
||||
@@ -67,12 +68,11 @@ class RedisMessageSystem:
|
||||
|
||||
# 获取频道类型以判断是否需要存储到数据库
|
||||
async with with_db() as session:
|
||||
from app.database.chat import ChatChannel, ChannelType
|
||||
from app.database.chat import ChannelType, ChatChannel
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
channel_result = await session.exec(
|
||||
select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
|
||||
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
|
||||
channel_type = channel_result.first()
|
||||
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
|
||||
|
||||
@@ -132,17 +132,14 @@ class RedisMessageSystem:
|
||||
|
||||
if is_multiplayer:
|
||||
logger.info(
|
||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
|
||||
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
|
||||
" will not be persisted to database"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Message {message_id} sent to Redis cache for channel {channel_id}"
|
||||
)
|
||||
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
|
||||
return response
|
||||
|
||||
async def get_messages(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[ChatMessageResp]:
|
||||
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
|
||||
"""
|
||||
获取频道消息 - 优先从 Redis 获取最新消息
|
||||
|
||||
@@ -166,9 +163,7 @@ class RedisMessageSystem:
|
||||
# 获取发送者信息
|
||||
sender = await session.get(User, msg_data["sender_id"])
|
||||
if sender:
|
||||
user_resp = await UserResp.from_db(
|
||||
sender, session, RANKING_INCLUDES
|
||||
)
|
||||
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
|
||||
|
||||
if user_resp.statistics is None:
|
||||
from app.database.statistics import UserStatisticsResp
|
||||
@@ -223,39 +218,28 @@ class RedisMessageSystem:
|
||||
async def _generate_message_id(self, channel_id: int) -> int:
|
||||
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
|
||||
# 使用全局计数器确保所有频道的消息ID都是严格递增的
|
||||
message_id = await self._redis_exec(
|
||||
self.redis.incr, "global_message_id_counter"
|
||||
)
|
||||
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
|
||||
|
||||
# 同时更新频道的最后消息ID,用于客户端状态同步
|
||||
await self._redis_exec(
|
||||
self.redis.set, f"channel:{channel_id}:last_msg_id", message_id
|
||||
)
|
||||
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
|
||||
|
||||
return message_id
|
||||
|
||||
async def _store_to_redis(
|
||||
self, message_id: int, channel_id: int, message_data: dict[str, Any]
|
||||
):
|
||||
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
|
||||
"""存储消息到 Redis"""
|
||||
try:
|
||||
# 检查是否是多人房间消息
|
||||
is_multiplayer = message_data.get("is_multiplayer", False)
|
||||
|
||||
|
||||
# 存储消息数据
|
||||
await self._redis_exec(
|
||||
self.redis.hset,
|
||||
f"msg:{channel_id}:{message_id}",
|
||||
mapping={
|
||||
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
|
||||
for k, v in message_data.items()
|
||||
},
|
||||
mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()},
|
||||
)
|
||||
|
||||
# 设置消息过期时间(7天)
|
||||
await self._redis_exec(
|
||||
self.redis.expire, f"msg:{channel_id}:{message_id}", 604800
|
||||
)
|
||||
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
|
||||
|
||||
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
|
||||
channel_messages_key = f"channel:{channel_id}:messages"
|
||||
@@ -264,14 +248,10 @@ class RedisMessageSystem:
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(
|
||||
f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}"
|
||||
)
|
||||
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
except Exception as type_check_error:
|
||||
logger.warning(
|
||||
f"Failed to check key type for {channel_messages_key}: {type_check_error}"
|
||||
)
|
||||
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
|
||||
# 如果检查失败,直接删除键以确保清理
|
||||
await self._redis_exec(self.redis.delete, channel_messages_key)
|
||||
|
||||
@@ -283,15 +263,11 @@ class RedisMessageSystem:
|
||||
)
|
||||
|
||||
# 保持频道消息列表大小(最多1000条)
|
||||
await self._redis_exec(
|
||||
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
|
||||
)
|
||||
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
|
||||
|
||||
# 只有非多人房间消息才添加到待持久化队列
|
||||
if not is_multiplayer:
|
||||
await self._redis_exec(
|
||||
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
|
||||
)
|
||||
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
|
||||
logger.debug(f"Message {message_id} added to persistence queue")
|
||||
else:
|
||||
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
|
||||
@@ -300,9 +276,7 @@ class RedisMessageSystem:
|
||||
logger.error(f"Failed to store message to Redis: {e}")
|
||||
raise
|
||||
|
||||
async def _get_from_redis(
|
||||
self, channel_id: int, limit: int = 50, since: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
|
||||
"""从 Redis 获取消息"""
|
||||
try:
|
||||
# 获取消息键列表,按消息ID排序
|
||||
@@ -340,9 +314,7 @@ class RedisMessageSystem:
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
if k in ["grade_counts", "level"] or v.startswith(
|
||||
("{", "[")
|
||||
):
|
||||
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
|
||||
message_data[k] = json.loads(v)
|
||||
elif k in ["message_id", "channel_id", "sender_id"]:
|
||||
message_data[k] = int(v)
|
||||
@@ -368,9 +340,7 @@ class RedisMessageSystem:
|
||||
logger.error(f"Failed to get messages from Redis: {e}")
|
||||
return []
|
||||
|
||||
async def _backfill_from_database(
|
||||
self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int
|
||||
):
|
||||
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
|
||||
"""从数据库补充历史消息"""
|
||||
try:
|
||||
# 找到最小的消息ID
|
||||
@@ -404,9 +374,7 @@ class RedisMessageSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to backfill from database: {e}")
|
||||
|
||||
async def _get_from_database_only(
|
||||
self, channel_id: int, limit: int, since: int
|
||||
) -> list[ChatMessageResp]:
|
||||
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
|
||||
"""仅从数据库获取消息(回退方案)"""
|
||||
try:
|
||||
async with with_db() as session:
|
||||
@@ -417,20 +385,14 @@ class RedisMessageSystem:
|
||||
if since > 0:
|
||||
# 获取指定ID之后的消息,按ID正序
|
||||
query = query.where(col(ChatMessage.message_id) > since)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(
|
||||
limit
|
||||
)
|
||||
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
|
||||
else:
|
||||
# 获取最新消息,按ID倒序(最新的在前面)
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(
|
||||
limit
|
||||
)
|
||||
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
|
||||
|
||||
messages = (await session.exec(query)).all()
|
||||
|
||||
results = [
|
||||
await ChatMessageResp.from_db(msg, session) for msg in messages
|
||||
]
|
||||
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
|
||||
|
||||
# 如果是 since > 0,保持正序;否则反转为时间正序
|
||||
if since == 0:
|
||||
@@ -451,9 +413,7 @@ class RedisMessageSystem:
|
||||
# 获取待处理的消息
|
||||
message_keys = []
|
||||
for _ in range(self.max_batch_size):
|
||||
key = await self._redis_exec(
|
||||
self.redis.brpop, ["pending_messages"], timeout=1
|
||||
)
|
||||
key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1)
|
||||
if key:
|
||||
# key 是 (queue_name, value) 的元组
|
||||
value = key[1]
|
||||
@@ -483,9 +443,7 @@ class RedisMessageSystem:
|
||||
channel_id, message_id = map(int, key.split(":"))
|
||||
|
||||
# 从 Redis 获取消息数据
|
||||
raw_data = await self._redis_exec(
|
||||
self.redis.hgetall, f"msg:{channel_id}:{message_id}"
|
||||
)
|
||||
raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}")
|
||||
|
||||
if not raw_data:
|
||||
continue
|
||||
@@ -546,9 +504,7 @@ class RedisMessageSystem:
|
||||
# 提交批次
|
||||
try:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"Batch of {len(message_keys)} messages committed to database"
|
||||
)
|
||||
logger.info(f"Batch of {len(message_keys)} messages committed to database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to commit message batch: {e}")
|
||||
await session.rollback()
|
||||
@@ -559,7 +515,7 @@ class RedisMessageSystem:
|
||||
self._running = True
|
||||
self._batch_timer = asyncio.create_task(self._batch_persist_to_database())
|
||||
# 启动时初始化消息ID计数器
|
||||
asyncio.create_task(self._initialize_message_counter())
|
||||
bg_tasks.add_task(self._initialize_message_counter)
|
||||
logger.info("Redis message system started")
|
||||
|
||||
async def _initialize_message_counter(self):
|
||||
@@ -576,27 +532,19 @@ class RedisMessageSystem:
|
||||
max_id = result.one() or 0
|
||||
|
||||
# 检查 Redis 中的计数器值
|
||||
current_counter = await self._redis_exec(
|
||||
self.redis.get, "global_message_id_counter"
|
||||
)
|
||||
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
|
||||
current_counter = int(current_counter) if current_counter else 0
|
||||
|
||||
# 设置计数器为两者中的最大值
|
||||
initial_counter = max(max_id, current_counter)
|
||||
await self._redis_exec(
|
||||
self.redis.set, "global_message_id_counter", initial_counter
|
||||
)
|
||||
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
|
||||
|
||||
logger.info(
|
||||
f"Initialized global message ID counter to {initial_counter}"
|
||||
)
|
||||
logger.info(f"Initialized global message ID counter to {initial_counter}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize message counter: {e}")
|
||||
# 如果初始化失败,设置一个安全的起始值
|
||||
await self._redis_exec(
|
||||
self.redis.setnx, "global_message_id_counter", 1000000
|
||||
)
|
||||
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
|
||||
|
||||
async def _cleanup_redis_keys(self):
|
||||
"""清理可能存在问题的 Redis 键"""
|
||||
@@ -612,9 +560,7 @@ class RedisMessageSystem:
|
||||
try:
|
||||
key_type = await self._redis_exec(self.redis.type, key)
|
||||
if key_type and key_type != "zset":
|
||||
logger.warning(
|
||||
f"Cleaning up Redis key {key} with wrong type: {key_type}"
|
||||
)
|
||||
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
|
||||
await self._redis_exec(self.redis.delete, key)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")
|
||||
|
||||
@@ -14,15 +14,11 @@ from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def create_playlist_room_from_api(
|
||||
session: AsyncSession, room: APIUploadedRoom, host_id: int
|
||||
) -> Room:
|
||||
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
|
||||
db_room = room.to_room()
|
||||
db_room.host_id = host_id
|
||||
db_room.starts_at = datetime.now(UTC)
|
||||
db_room.ends_at = db_room.starts_at + timedelta(
|
||||
minutes=db_room.duration if db_room.duration is not None else 0
|
||||
)
|
||||
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
|
||||
session.add(db_room)
|
||||
await session.commit()
|
||||
await session.refresh(db_room)
|
||||
@@ -87,13 +83,9 @@ async def create_playlist_room(
|
||||
return db_room
|
||||
|
||||
|
||||
async def add_playlists_to_room(
|
||||
session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int
|
||||
):
|
||||
async def add_playlists_to_room(session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int):
|
||||
for item in playlist:
|
||||
if not (
|
||||
await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))
|
||||
).first():
|
||||
if not (await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))).first():
|
||||
fetcher = await get_fetcher()
|
||||
await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id)
|
||||
item.id = await Playlist.get_next_id_for_room(room_id, session)
|
||||
|
||||
@@ -4,15 +4,15 @@ API 状态管理 - 模拟 osu! 的 APIState 和会话管理
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIState(str, Enum):
|
||||
"""API 连接状态,对应 osu! 的 APIState"""
|
||||
|
||||
OFFLINE = "offline"
|
||||
CONNECTING = "connecting"
|
||||
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
|
||||
@@ -22,6 +22,7 @@ class APIState(str, Enum):
|
||||
|
||||
class UserSession(BaseModel):
|
||||
"""用户会话信息"""
|
||||
|
||||
user_id: int
|
||||
username: str
|
||||
email: str
|
||||
@@ -38,10 +39,10 @@ class UserSession(BaseModel):
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, UserSession] = {}
|
||||
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_id: int,
|
||||
@@ -49,19 +50,19 @@ class SessionManager:
|
||||
email: str,
|
||||
ip_address: str,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False
|
||||
is_new_location: bool = False,
|
||||
) -> UserSession:
|
||||
"""创建新的用户会话"""
|
||||
import secrets
|
||||
|
||||
|
||||
session_token = secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# 根据是否为新位置决定初始状态
|
||||
if is_new_location:
|
||||
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
|
||||
else:
|
||||
state = APIState.ONLINE
|
||||
|
||||
|
||||
session = UserSession(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
@@ -71,33 +72,33 @@ class SessionManager:
|
||||
requires_verification=is_new_location,
|
||||
ip_address=ip_address,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location
|
||||
is_new_location=is_new_location,
|
||||
)
|
||||
|
||||
|
||||
self._sessions[session_token] = session
|
||||
return session
|
||||
|
||||
|
||||
def get_session(self, session_token: str) -> UserSession | None:
|
||||
"""获取会话"""
|
||||
return self._sessions.get(session_token)
|
||||
|
||||
|
||||
def update_session_state(self, session_token: str, state: APIState):
|
||||
"""更新会话状态"""
|
||||
if session_token in self._sessions:
|
||||
self._sessions[session_token].state = state
|
||||
|
||||
|
||||
def mark_verification_sent(self, session_token: str):
|
||||
"""标记验证邮件已发送"""
|
||||
if session_token in self._sessions:
|
||||
session = self._sessions[session_token]
|
||||
session.verification_sent = True
|
||||
session.last_verification_attempt = datetime.now()
|
||||
|
||||
|
||||
def increment_failed_attempts(self, session_token: str):
|
||||
"""增加失败尝试次数"""
|
||||
if session_token in self._sessions:
|
||||
self._sessions[session_token].failed_attempts += 1
|
||||
|
||||
|
||||
def verify_session(self, session_token: str) -> bool:
|
||||
"""验证会话成功"""
|
||||
if session_token in self._sessions:
|
||||
@@ -106,11 +107,11 @@ class SessionManager:
|
||||
session.requires_verification = False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def remove_session(self, session_token: str):
|
||||
"""移除会话"""
|
||||
self._sessions.pop(session_token, None)
|
||||
|
||||
|
||||
def cleanup_expired_sessions(self):
|
||||
"""清理过期会话"""
|
||||
# 这里可以实现清理逻辑
|
||||
|
||||
@@ -26,14 +26,12 @@ async def cleanup_stale_online_users() -> tuple[int, int]:
|
||||
|
||||
# 检查在线用户的最后活动时间
|
||||
current_time = datetime.utcnow()
|
||||
stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期
|
||||
stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期 # noqa: F841
|
||||
|
||||
# 对于在线用户,我们检查metadata在线标记
|
||||
stale_online_users = []
|
||||
for user_id in online_users:
|
||||
user_id_str = (
|
||||
user_id.decode() if isinstance(user_id, bytes) else str(user_id)
|
||||
)
|
||||
user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
|
||||
metadata_key = f"metadata:online:{user_id_str}"
|
||||
|
||||
# 如果metadata标记不存在,说明用户已经离线
|
||||
@@ -42,9 +40,7 @@ async def cleanup_stale_online_users() -> tuple[int, int]:
|
||||
|
||||
# 清理过期的在线用户
|
||||
if stale_online_users:
|
||||
await _redis_exec(
|
||||
redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users
|
||||
)
|
||||
await _redis_exec(redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users)
|
||||
online_cleaned = len(stale_online_users)
|
||||
logger.info(f"Cleaned {online_cleaned} stale online users")
|
||||
|
||||
@@ -52,22 +48,19 @@ async def cleanup_stale_online_users() -> tuple[int, int]:
|
||||
# 只有当用户明确不在任何hub连接中时才移除
|
||||
stale_playing_users = []
|
||||
for user_id in playing_users:
|
||||
user_id_str = (
|
||||
user_id.decode() if isinstance(user_id, bytes) else str(user_id)
|
||||
)
|
||||
user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
|
||||
metadata_key = f"metadata:online:{user_id_str}"
|
||||
|
||||
|
||||
# 只有当metadata在线标记完全不存在且用户也不在在线列表中时,
|
||||
# 才认为用户真正离线
|
||||
if (not await redis_async.exists(metadata_key) and
|
||||
user_id_str not in [u.decode() if isinstance(u, bytes) else str(u) for u in online_users]):
|
||||
if not await redis_async.exists(metadata_key) and user_id_str not in [
|
||||
u.decode() if isinstance(u, bytes) else str(u) for u in online_users
|
||||
]:
|
||||
stale_playing_users.append(user_id_str)
|
||||
|
||||
# 清理过期的游玩用户
|
||||
if stale_playing_users:
|
||||
await _redis_exec(
|
||||
redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users
|
||||
)
|
||||
await _redis_exec(redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users)
|
||||
playing_cleaned = len(stale_playing_users)
|
||||
logger.info(f"Cleaned {playing_cleaned} stale playing users")
|
||||
|
||||
|
||||
@@ -61,26 +61,29 @@ class StatsScheduler:
|
||||
try:
|
||||
# 计算下次区间结束时间
|
||||
now = datetime.utcnow()
|
||||
|
||||
|
||||
# 计算当前区间的结束时间
|
||||
current_minute = (now.minute // 30) * 30
|
||||
current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30)
|
||||
|
||||
current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(
|
||||
minutes=30
|
||||
)
|
||||
|
||||
# 如果当前时间已经超过了当前区间结束时间,说明需要等待下一个区间结束
|
||||
if now >= current_interval_end:
|
||||
current_interval_end += timedelta(minutes=30)
|
||||
|
||||
|
||||
# 计算需要等待的时间
|
||||
sleep_seconds = (current_interval_end - now).total_seconds()
|
||||
|
||||
|
||||
# 添加小的缓冲时间,确保区间真正结束后再处理
|
||||
sleep_seconds += 10 # 额外等待10秒
|
||||
|
||||
|
||||
# 限制等待时间范围
|
||||
sleep_seconds = max(min(sleep_seconds, 32 * 60), 10)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Next interval finalization in {sleep_seconds / 60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}"
|
||||
f"Next interval finalization in {sleep_seconds / 60:.1f} "
|
||||
f"minutes at {current_interval_end.strftime('%H:%M:%S')}"
|
||||
)
|
||||
await asyncio.sleep(sleep_seconds)
|
||||
|
||||
@@ -137,7 +140,8 @@ class StatsScheduler:
|
||||
online_cleaned, playing_cleaned = await cleanup_stale_online_users()
|
||||
if online_cleaned > 0 or playing_cleaned > 0:
|
||||
logger.info(
|
||||
f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users"
|
||||
f"Initial cleanup: removed {online_cleaned} stale online users,"
|
||||
f" {playing_cleaned} stale playing users"
|
||||
)
|
||||
|
||||
await refresh_redis_key_expiry()
|
||||
|
||||
@@ -31,9 +31,7 @@ class RedisSubscriber:
|
||||
|
||||
async def listen(self):
|
||||
while True:
|
||||
message = await self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=None
|
||||
)
|
||||
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=None)
|
||||
if message is not None and message["type"] == "message":
|
||||
matched_handlers: list[Callable[[str, str], Awaitable[Any]]] = []
|
||||
|
||||
@@ -53,10 +51,7 @@ class RedisSubscriber:
|
||||
|
||||
if matched_handlers:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
handler(message["channel"], message["data"])
|
||||
for handler in matched_handlers
|
||||
]
|
||||
*[handler(message["channel"], message["data"]) for handler in matched_handlers]
|
||||
)
|
||||
|
||||
def start(self):
|
||||
|
||||
@@ -46,12 +46,7 @@ class ScoreSubscriber(RedisSubscriber):
|
||||
return
|
||||
async with with_db() as session:
|
||||
score = await session.get(Score, score_id)
|
||||
if (
|
||||
not score
|
||||
or not score.passed
|
||||
or score.room_id is None
|
||||
or score.playlist_item_id is None
|
||||
):
|
||||
if not score or not score.passed or score.room_id is None or score.playlist_item_id is None:
|
||||
return
|
||||
if not self.room_subscriber.get(score.room_id, []):
|
||||
return
|
||||
|
||||
@@ -47,17 +47,13 @@ class UserCacheService:
|
||||
self._refreshing = False
|
||||
self._background_tasks: set = set()
|
||||
|
||||
def _get_v1_user_cache_key(
|
||||
self, user_id: int, ruleset: GameMode | None = None
|
||||
) -> str:
|
||||
def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
|
||||
"""生成 V1 用户缓存键"""
|
||||
if ruleset:
|
||||
return f"v1_user:{user_id}:ruleset:{ruleset}"
|
||||
return f"v1_user:{user_id}"
|
||||
|
||||
async def get_v1_user_from_cache(
|
||||
self, user_id: int, ruleset: GameMode | None = None
|
||||
) -> dict | None:
|
||||
async def get_v1_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> dict | None:
|
||||
"""从缓存获取 V1 用户信息"""
|
||||
try:
|
||||
cache_key = self._get_v1_user_cache_key(user_id, ruleset)
|
||||
@@ -96,9 +92,7 @@ class UserCacheService:
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
logger.info(
|
||||
f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
|
||||
)
|
||||
logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating V1 user cache: {e}")
|
||||
|
||||
@@ -126,9 +120,7 @@ class UserCacheService:
|
||||
"""生成用户谱面集缓存键"""
|
||||
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
|
||||
|
||||
async def get_user_from_cache(
|
||||
self, user_id: int, ruleset: GameMode | None = None
|
||||
) -> UserResp | None:
|
||||
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
|
||||
"""从缓存获取用户信息"""
|
||||
try:
|
||||
cache_key = self._get_user_cache_key(user_id, ruleset)
|
||||
@@ -172,14 +164,10 @@ class UserCacheService:
|
||||
) -> list[ScoreResp] | None:
|
||||
"""从缓存获取用户成绩"""
|
||||
try:
|
||||
cache_key = self._get_user_scores_cache_key(
|
||||
user_id, score_type, mode, limit, offset
|
||||
)
|
||||
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(
|
||||
f"User scores cache hit for user {user_id}, type {score_type}"
|
||||
)
|
||||
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
|
||||
data = json.loads(cached_data)
|
||||
return [ScoreResp(**score_data) for score_data in data]
|
||||
return None
|
||||
@@ -201,16 +189,12 @@ class UserCacheService:
|
||||
try:
|
||||
if expire_seconds is None:
|
||||
expire_seconds = settings.user_scores_cache_expire_seconds
|
||||
cache_key = self._get_user_scores_cache_key(
|
||||
user_id, score_type, mode, limit, offset
|
||||
)
|
||||
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
|
||||
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
|
||||
scores_json_list = [score.model_dump_json() for score in scores]
|
||||
cached_data = f"[{','.join(scores_json_list)}]"
|
||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||
logger.debug(
|
||||
f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
|
||||
)
|
||||
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching user scores: {e}")
|
||||
|
||||
@@ -219,14 +203,10 @@ class UserCacheService:
|
||||
) -> list[Any] | None:
|
||||
"""从缓存获取用户谱面集"""
|
||||
try:
|
||||
cache_key = self._get_user_beatmapsets_cache_key(
|
||||
user_id, beatmapset_type, limit, offset
|
||||
)
|
||||
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(
|
||||
f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
|
||||
)
|
||||
logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
|
||||
return json.loads(cached_data)
|
||||
return None
|
||||
except Exception as e:
|
||||
@@ -246,9 +226,7 @@ class UserCacheService:
|
||||
try:
|
||||
if expire_seconds is None:
|
||||
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
|
||||
cache_key = self._get_user_beatmapsets_cache_key(
|
||||
user_id, beatmapset_type, limit, offset
|
||||
)
|
||||
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
|
||||
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
|
||||
serialized_beatmapsets = []
|
||||
for bms in beatmapsets:
|
||||
@@ -258,9 +236,7 @@ class UserCacheService:
|
||||
serialized_beatmapsets.append(safe_json_dumps(bms))
|
||||
cached_data = f"[{','.join(serialized_beatmapsets)}]"
|
||||
await self.redis.setex(cache_key, expire_seconds, cached_data)
|
||||
logger.debug(
|
||||
f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
|
||||
)
|
||||
logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching user beatmapsets: {e}")
|
||||
|
||||
@@ -276,9 +252,7 @@ class UserCacheService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating user cache: {e}")
|
||||
|
||||
async def invalidate_user_scores_cache(
|
||||
self, user_id: int, mode: GameMode | None = None
|
||||
):
|
||||
async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
|
||||
"""使用户成绩缓存失效"""
|
||||
try:
|
||||
# 删除用户成绩相关缓存
|
||||
@@ -287,9 +261,7 @@ class UserCacheService:
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
logger.info(
|
||||
f"Invalidated {len(keys)} score cache entries for user {user_id}"
|
||||
)
|
||||
logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating user scores cache: {e}")
|
||||
|
||||
@@ -303,9 +275,7 @@ class UserCacheService:
|
||||
logger.info(f"Preloading cache for {len(user_ids)} users")
|
||||
|
||||
# 批量获取用户
|
||||
users = (
|
||||
await session.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
).all()
|
||||
users = (await session.exec(select(User).where(col(User.id).in_(user_ids)))).all()
|
||||
|
||||
# 串行缓存用户信息,避免并发数据库访问问题
|
||||
cached_count = 0
|
||||
@@ -332,9 +302,7 @@ class UserCacheService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error caching single user {user.id}: {e}")
|
||||
|
||||
async def refresh_user_cache_on_score_submit(
|
||||
self, session: AsyncSession, user_id: int, mode: GameMode
|
||||
):
|
||||
async def refresh_user_cache_on_score_submit(self, session: AsyncSession, user_id: int, mode: GameMode):
|
||||
"""成绩提交后刷新用户缓存"""
|
||||
try:
|
||||
# 使相关缓存失效(包括 v1 和 v2)
|
||||
@@ -367,24 +335,12 @@ class UserCacheService:
|
||||
continue
|
||||
|
||||
return {
|
||||
"cached_users": len(
|
||||
[
|
||||
k
|
||||
for k in user_keys
|
||||
if ":scores:" not in k and ":beatmapsets:" not in k
|
||||
]
|
||||
),
|
||||
"cached_v1_users": len(
|
||||
[k for k in v1_user_keys if ":scores:" not in k]
|
||||
),
|
||||
"cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
|
||||
"cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
|
||||
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
|
||||
"cached_user_beatmapsets": len(
|
||||
[k for k in user_keys if ":beatmapsets:" in k]
|
||||
),
|
||||
"cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
|
||||
"total_cached_entries": len(all_keys),
|
||||
"estimated_total_size_mb": (
|
||||
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||
),
|
||||
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
|
||||
"refreshing": self._refreshing,
|
||||
}
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user