refactor(project): make pyright & ruff happy

MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions
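The diffs below are mechanical lint fixes rather than behavior changes: multi-line statements that now fit on one line are joined (with # noqa: E501 kept where a long HTML template line is unavoidable), str(e) inside f-strings becomes the conversion flag {e!s} (ruff rule RUF010), single-quoted strings and pre-PEP 604/585 typing annotations are normalized to double quotes and X | None / dict[...], imports are re-sorted, and fire-and-forget asyncio tasks are routed through a background-task helper. A small sketch of the f-string and annotation rewrites, assuming ruff's default fixers:

from typing import Any

# Before: str() inside an f-string (ruff RUF010) and Optional[Dict[...]] annotations
# def describe(e: Exception, meta: Optional[Dict[str, Any]]) -> str:
#     return f"failed: {str(e)}"

# After: the forms this commit standardizes on
def describe(e: Exception, meta: dict[str, Any] | None) -> str:
    return f"failed: {e!s}"

print(describe(ValueError("boom"), None))  # failed: boom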

View File

@@ -63,10 +63,7 @@ class BeatmapCacheService:
if preload_tasks:
results = await asyncio.gather(*preload_tasks, return_exceptions=True)
success_count = sum(1 for r in results if r is True)
logger.info(
f"Preloaded {success_count}/{len(preload_tasks)} "
f"beatmaps successfully"
)
logger.info(f"Preloaded {success_count}/{len(preload_tasks)} beatmaps successfully")
except Exception as e:
logger.error(f"Error during beatmap preloading: {e}")
@@ -119,9 +116,7 @@ class BeatmapCacheService:
return {
"cached_beatmaps": len(keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
"preloading": self._preloading,
}
except Exception as e:
@@ -155,9 +150,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
return _cache_service
async def schedule_preload_task(
session: AsyncSession, redis: Redis, fetcher: "Fetcher"
):
async def schedule_preload_task(session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
"""
Scheduled preload task
"""

View File

@@ -192,22 +192,16 @@ class BeatmapDownloadService:
healthy_endpoints.sort(key=lambda x: x.priority)
return healthy_endpoints
def get_download_url(
self, beatmapset_id: int, no_video: bool, is_china: bool
) -> str:
def get_download_url(self, beatmapset_id: int, no_video: bool, is_china: bool) -> str:
"""获取下载URL带负载均衡和故障转移"""
healthy_endpoints = self.get_healthy_endpoints(is_china)
if not healthy_endpoints:
# If no endpoint is healthy, log the error and fall back to the highest-priority endpoint among all of them
logger.error(f"No healthy endpoints available for is_china={is_china}")
endpoints = (
self.china_endpoints if is_china else self.international_endpoints
)
endpoints = self.china_endpoints if is_china else self.international_endpoints
if not endpoints:
raise HTTPException(
status_code=503, detail="No download endpoints available"
)
raise HTTPException(status_code=503, detail="No download endpoints available")
endpoint = min(endpoints, key=lambda x: x.priority)
else:
# Use the first healthy endpoint (the list is already sorted by priority)
@@ -218,9 +212,7 @@ class BeatmapDownloadService:
video_type = "novideo" if no_video else "full"
return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
elif endpoint.name == "Nerinyan":
return endpoint.url_template.format(
sid=beatmapset_id, no_video="true" if no_video else "false"
)
return endpoint.url_template.format(sid=beatmapset_id, no_video="true" if no_video else "false")
elif endpoint.name == "OsuDirect":
# osu.direct does not appear to have a no_video parameter; use the base URL directly
return endpoint.url_template.format(sid=beatmapset_id)
@@ -239,9 +231,7 @@ class BeatmapDownloadService:
for name, status in self.endpoint_status.items():
status_info["endpoints"][name] = {
"healthy": status.is_healthy,
"last_check": status.last_check.isoformat()
if status.last_check
else None,
"last_check": status.last_check.isoformat() if status.last_check else None,
"consecutive_failures": status.consecutive_failures,
"last_error": status.last_error,
"priority": status.endpoint.priority,

View File

@@ -11,9 +11,7 @@ from app.models.score import GameMode
from sqlmodel import col, exists, select, update
@get_scheduler().scheduled_job(
"cron", hour=0, minute=0, second=0, id="calculate_user_rank"
)
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="calculate_user_rank")
async def calculate_user_rank(is_today: bool = False):
today = datetime.now(UTC).date()
target_date = today if is_today else today - timedelta(days=1)

View File

@@ -11,9 +11,7 @@ from sqlmodel import exists, select
async def create_banchobot():
async with with_db() as session:
is_exist = (
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
).first()
is_exist = (await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))).first()
if not is_exist:
banchobot = User(
username="BanchoBot",

View File

@@ -82,8 +82,7 @@ async def daily_challenge_job():
if beatmap is None or ruleset_id is None:
logger.warning(
f"[DailyChallenge] Missing required data for daily challenge {now}."
" Will try again in 5 minutes."
f"[DailyChallenge] Missing required data for daily challenge {now}. Will try again in 5 minutes."
)
get_scheduler().add_job(
daily_challenge_job,
@@ -104,9 +103,7 @@ async def daily_challenge_job():
else:
allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list)
next_day = (now + timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
next_day = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
room = await create_daily_challenge_room(
beatmap=beatmap_int,
ruleset_id=ruleset_id_int,
@@ -114,24 +111,13 @@ async def daily_challenge_job():
allowed_mods=allowed_mods_list,
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
)
await MetadataHubs.broadcast_call(
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
)
logger.success(
"[DailyChallenge] Added today's daily challenge: "
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
)
await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
logger.success(f"[DailyChallenge] Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
return
except (ValueError, json.JSONDecodeError) as e:
logger.warning(
f"[DailyChallenge] Error processing daily challenge data: {e}"
" Will try again in 5 minutes."
)
logger.warning(f"[DailyChallenge] Error processing daily challenge data: {e} Will try again in 5 minutes.")
except Exception as e:
logger.exception(
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
" Will try again in 5 minutes."
)
logger.exception(f"[DailyChallenge] Unexpected error in daily challenge job: {e} Will try again in 5 minutes.")
get_scheduler().add_job(
daily_challenge_job,
"date",
@@ -139,9 +125,7 @@ async def daily_challenge_job():
)
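Each failure path above re-registers the job once via a one-shot date trigger five minutes out instead of looping in place. A minimal self-rescheduling sketch with APScheduler (the scheduler setup is illustrative; the repo wraps its own get_scheduler(), and the elided run_date argument is assumed here):

from datetime import UTC, datetime, timedelta

from apscheduler.schedulers.asyncio import AsyncIOScheduler

scheduler = AsyncIOScheduler()

async def daily_challenge_job() -> None:
    try:
        ...  # fetch the beatmap/ruleset and create the room
    except Exception:
        # One-shot retry five minutes from now, mirroring the pattern above
        scheduler.add_job(
            daily_challenge_job,
            "date",
            run_date=datetime.now(UTC) + timedelta(minutes=5),
        )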
@get_scheduler().scheduled_job(
"cron", hour=0, minute=1, second=0, id="daily_challenge_last_top"
)
@get_scheduler().scheduled_job("cron", hour=0, minute=1, second=0, id="daily_challenge_last_top")
async def process_daily_challenge_top():
async with with_db() as session:
now = datetime.now(UTC)
@@ -182,11 +166,7 @@ async def process_daily_challenge_top():
await session.commit()
del s
user_ids = (
await session.exec(
select(User.id).where(col(User.id).not_in(participated_users))
)
).all()
user_ids = (await session.exec(select(User.id).where(col(User.id).not_in(participated_users)))).all()
for id in user_ids:
stats = await session.get(DailyChallengeStats, id)
if stats is None: # no stats record yet; skip

View File

@@ -4,14 +4,13 @@
from __future__ import annotations
from datetime import datetime, UTC, timedelta
from datetime import UTC, datetime, timedelta
from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger
from sqlmodel import select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy import and_
class DatabaseCleanupService:
@@ -21,211 +20,207 @@ class DatabaseCleanupService:
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
"""
Clean up expired email verification codes
Args:
db: database session
Returns:
int: number of records cleaned
"""
try:
# Find expired verification code records
current_time = datetime.now(UTC)
stmt = select(EmailVerification).where(
EmailVerification.expires_at < current_time
)
stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
result = await db.exec(stmt)
expired_codes = result.all()
# Delete the expired records
deleted_count = 0
for code in expired_codes:
await db.delete(code)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {e!s}")
return 0
@staticmethod
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
"""
Clean up expired login sessions
Args:
db: database session
Returns:
int: number of records cleaned
"""
try:
# Find expired login session records
current_time = datetime.now(UTC)
stmt = select(LoginSession).where(
LoginSession.expires_at < current_time
)
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
result = await db.exec(stmt)
expired_sessions = result.all()
# Delete the expired records
deleted_count = 0
for session in expired_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {e!s}")
return 0
@staticmethod
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
"""
Clean up old used verification code records
Args:
db: database session
days_old: clean up used records older than this many days (default 7)
Returns:
int: number of records cleaned
"""
try:
# Find used verification code records from before the cutoff
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(EmailVerification).where(
EmailVerification.is_used == True
)
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
result = await db.exec(stmt)
all_used_codes = result.all()
# Filter down to the old records
old_used_codes = [
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_time
]
old_used_codes = [code for code in all_used_codes if code.used_at and code.used_at < cutoff_time]
# Delete the old used records
deleted_count = 0
for code in old_used_codes:
await db.delete(code)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days"
)
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
"""
Clean up old verified session records
Args:
db: database session
days_old: clean up verified records older than this many days (default 30)
Returns:
int: number of records cleaned
"""
try:
# Find verified session records from before the cutoff
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(LoginSession).where(
LoginSession.is_verified == True
)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
result = await db.exec(stmt)
all_verified_sessions = result.all()
# Filter down to the old records
old_verified_sessions = [
session for session in all_verified_sessions
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time
]
# Delete the old verified records
deleted_count = 0
for session in old_verified_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
)
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
return 0
@staticmethod
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
"""
Run the full cleanup pipeline
Args:
db: database session
Returns:
dict: result counts for each cleanup task
"""
results = {}
# Clean up expired verification codes
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# Clean up expired login sessions
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# Clean up used verification codes older than 7 days
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# Clean up verified sessions older than 30 days
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
total_cleaned = sum(results.values())
if total_cleaned > 0:
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
logger.debug(
f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}"
)
return results
@staticmethod
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
"""
Get cleanup statistics
Args:
db: database session
Returns:
dict: the statistics
"""
@@ -233,57 +228,54 @@ class DatabaseCleanupService:
current_time = datetime.now(UTC)
cutoff_7_days = current_time - timedelta(days=7)
cutoff_30_days = current_time - timedelta(days=30)
# Count expired verification codes
expired_codes_stmt = select(EmailVerification).where(
EmailVerification.expires_at < current_time
)
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all())
# Count expired login sessions
expired_sessions_stmt = select(LoginSession).where(
LoginSession.expires_at < current_time
)
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
# Count used verification codes older than 7 days
old_used_codes_stmt = select(EmailVerification).where(
EmailVerification.is_used == True
)
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all()
old_used_codes_count = len([
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_7_days
])
# Count verified sessions older than 30 days
old_verified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == True
old_used_codes_count = len(
[code for code in all_used_codes if code.used_at and code.used_at < cutoff_7_days]
)
# Count verified sessions older than 30 days
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len([
session for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_30_days
])
old_verified_sessions_count = len(
[
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_30_days
]
)
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
"total_cleanable": expired_codes_count
+ expired_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
}
except Exception as e:
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {e!s}")
return {
"expired_verification_codes": 0,
"expired_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"total_cleanable": 0
"total_cleanable": 0,
}
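The one recurring semantic-looking change in this file is purely for the type checker: comparisons such as EmailVerification.is_used == True are valid SQLAlchemy but trip ruff's E712 (comparison to True), and pyright types the attribute as a plain bool. Wrapping the attribute in sqlmodel's col() yields a typed column expression, and .is_(True) renders as an SQL IS comparison. A minimal sketch of the pattern, assuming a trivial SQLModel table:

from sqlmodel import Field, SQLModel, col, select

class EmailVerification(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    is_used: bool = False

# Before: ruff E712, and pyright sees a plain bool, not a column expression
# stmt = select(EmailVerification).where(EmailVerification.is_used == True)

# After: col() gives the checker a column object with an .is_() method
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))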

View File

@@ -8,17 +8,18 @@ from __future__ import annotations
import asyncio
import concurrent.futures
from datetime import datetime
import json
import uuid
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from typing import Dict, Any, Optional
import redis as sync_redis # add a synchronous Redis import
from email.mime.text import MIMEText
import json
import smtplib
from typing import Any
import uuid
from app.config import settings
from app.dependencies.database import redis_message_client # use the synchronous Redis client
from app.log import logger
from app.utils import bg_tasks # background-task helper
import redis as sync_redis
class EmailQueue:
@@ -30,14 +31,14 @@ class EmailQueue:
self._processing = False
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self._retry_limit = 3 # retry limit
# Email configuration
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
self.smtp_port = getattr(settings, 'smtp_port', 587)
self.smtp_username = getattr(settings, 'smtp_username', '')
self.smtp_password = getattr(settings, 'smtp_password', '')
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
self.from_name = getattr(settings, 'from_name', 'osu! server')
self.smtp_server = getattr(settings, "smtp_server", "localhost")
self.smtp_port = getattr(settings, "smtp_port", 587)
self.smtp_username = getattr(settings, "smtp_username", "")
self.smtp_password = getattr(settings, "smtp_password", "")
self.from_email = getattr(settings, "from_email", "noreply@example.com")
self.from_name = getattr(settings, "from_name", "osu! server")
async def _run_in_executor(self, func, *args):
"""在线程池中运行同步操作"""
@@ -48,7 +49,7 @@ class EmailQueue:
"""启动邮件处理任务"""
if not self._processing:
self._processing = True
asyncio.create_task(self._process_email_queue())
bg_tasks.add_task(self._process_email_queue)
logger.info("Email queue processing started")
async def stop_processing(self):
@@ -56,27 +57,29 @@ class EmailQueue:
self._processing = False
logger.info("Email queue processing stopped")
async def enqueue_email(self,
to_email: str,
subject: str,
content: str,
html_content: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> str:
async def enqueue_email(
self,
to_email: str,
subject: str,
content: str,
html_content: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
Enqueue an email to be sent
Args:
to_email: recipient email address
subject: email subject
content: plain-text body
html_content: HTML body (if any)
metadata: extra metadata (e.g. a password-reset ID)
Returns:
the email task ID
"""
email_id = str(uuid.uuid4())
email_data = {
"id": email_id,
"to_email": to_email,
@@ -86,125 +89,117 @@ class EmailQueue:
"metadata": json.dumps(metadata) if metadata else "{}",
"created_at": datetime.now().isoformat(),
"status": "pending", # pending, sending, sent, failed
"retry_count": "0"
"retry_count": "0",
}
# Store the email data in Redis
await self._run_in_executor(
lambda: self.redis.hset(f"email:{email_id}", mapping=email_data)
)
await self._run_in_executor(lambda: self.redis.hset(f"email:{email_id}", mapping=email_data))
# Expire after 24 hours so data does not pile up
await self._run_in_executor(
self.redis.expire, f"email:{email_id}", 86400
)
await self._run_in_executor(self.redis.expire, f"email:{email_id}", 86400)
# Add to the send queue
await self._run_in_executor(
self.redis.lpush, "email_queue", email_id
)
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
logger.info(f"Email enqueued with id: {email_id} to {to_email}")
return email_id
async def get_email_status(self, email_id: str) -> Dict[str, Any]:
async def get_email_status(self, email_id: str) -> dict[str, Any]:
"""
Get an email's delivery status
Args:
email_id: the email task ID
Returns:
the email task's status info
"""
email_data = await self._run_in_executor(
self.redis.hgetall, f"email:{email_id}"
)
email_data = await self._run_in_executor(self.redis.hgetall, f"email:{email_id}")
# Decode the byte data returned by Redis
if email_data:
return {
k.decode("utf-8") if isinstance(k, bytes) else k:
v.decode("utf-8") if isinstance(v, bytes) else v
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
for k, v in email_data.items()
}
return {"status": "not_found"}
async def _process_email_queue(self):
"""处理邮件队列"""
logger.info("Starting email queue processor")
while self._processing:
try:
# Pop an email ID from the queue
def brpop_operation():
return self.redis.brpop(["email_queue"], timeout=5)
result = await self._run_in_executor(brpop_operation)
if not result:
await asyncio.sleep(1)
continue
# Unpack the result (list name and value)
queue_name, email_id = result
if isinstance(email_id, bytes):
email_id = email_id.decode("utf-8")
# Fetch the email data
email_data = await self.get_email_status(email_id)
if email_data.get("status") == "not_found":
logger.warning(f"Email data not found for id: {email_id}")
continue
# Mark the status as sending
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "sending"
)
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sending")
# Try to send the email
success = await self._send_email(email_data)
if success:
# Mark the status as sent
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sent")
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "sent"
)
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "sent_at", datetime.now().isoformat()
self.redis.hset,
f"email:{email_id}",
"sent_at",
datetime.now().isoformat(),
)
logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}")
else:
# Bump the retry count
retry_count = int(email_data.get("retry_count", "0")) + 1
if retry_count <= self._retry_limit:
# Re-enqueue for a later retry
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "retry_count", str(retry_count)
self.redis.hset,
f"email:{email_id}",
"retry_count",
str(retry_count),
)
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "pending")
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "pending"
self.redis.hset,
f"email:{email_id}",
"last_retry",
datetime.now().isoformat(),
)
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "last_retry", datetime.now().isoformat()
)
# Delayed retry with exponential backoff
delay = 60 * (2 ** (retry_count - 1)) # 1 minute, 2 minutes, 4 minutes...
# Schedule the delayed task
asyncio.create_task(self._delayed_retry(email_id, delay))
bg_tasks.add_task(self._delayed_retry, email_id, delay)
logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})")
else:
# Retry limit exceeded; mark as failed
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "failed"
)
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "failed")
logger.error(f"Email {email_id} failed after {retry_count} attempts")
except Exception as e:
logger.error(f"Error processing email queue: {e}")
await asyncio.sleep(5) # wait 5 seconds after an error
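The retry schedule above is plain exponential backoff: with the default retry limit of 3, a failing email is re-queued after 60 s, then 120 s, then 240 s before being marked failed. A two-line check of the arithmetic:

for retry_count in (1, 2, 3):
    print(60 * (2 ** (retry_count - 1)))  # 60, 120, 240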
@@ -212,53 +207,51 @@ class EmailQueue:
async def _delayed_retry(self, email_id: str, delay: int):
"""延迟重试发送邮件"""
await asyncio.sleep(delay)
await self._run_in_executor(
self.redis.lpush, "email_queue", email_id
)
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
logger.info(f"Re-queued email {email_id} for retry after {delay} seconds")
async def _send_email(self, email_data: Dict[str, Any]) -> bool:
async def _send_email(self, email_data: dict[str, Any]) -> bool:
"""
Actually send the email
Args:
email_data: the email data
Returns:
whether the send succeeded
"""
try:
# If email sending is disabled, only log
if not getattr(settings, 'enable_email_sending', True):
if not getattr(settings, "enable_email_sending", True):
logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}")
return True
# Build the message
msg = MIMEMultipart('alternative')
msg['From'] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email_data.get('to_email', '')
msg['Subject'] = email_data.get('subject', '')
msg = MIMEMultipart("alternative")
msg["From"] = f"{self.from_name} <{self.from_email}>"
msg["To"] = email_data.get("to_email", "")
msg["Subject"] = email_data.get("subject", "")
# Attach the plain-text body
content = email_data.get('content', '')
content = email_data.get("content", "")
if content:
msg.attach(MIMEText(content, 'plain', 'utf-8'))
msg.attach(MIMEText(content, "plain", "utf-8"))
# Attach the HTML body (if any)
html_content = email_data.get('html_content', '')
html_content = email_data.get("html_content", "")
if html_content:
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
msg.attach(MIMEText(html_content, "html", "utf-8"))
# Send the message
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password:
server.starttls()
server.login(self.smtp_username, self.smtp_password)
server.send_message(msg)
return True
except Exception as e:
logger.error(f"Failed to send email: {e}")
return False
@@ -267,10 +260,12 @@ class EmailQueue:
# Global email queue instance
email_queue = EmailQueue()
# Called on application startup
async def start_email_processor():
await email_queue.start_processing()
# Called on application shutdown
async def stop_email_processor():
await email_queue.stop_processing()
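Both this queue and the chat MessageQueue below replace bare asyncio.create_task(...) calls with bg_tasks.add_task(...). Ruff's RUF006 flags the former because a task whose only reference is discarded can be garbage-collected before it finishes. The repo's app.utils.bg_tasks is not shown in this diff; a plausible minimal implementation of such a helper (an assumption, not the project's actual code) is:

import asyncio
from typing import Any

class BackgroundTasks:
    """Keeps strong references to fire-and-forget tasks so they are not GC'd."""

    def __init__(self) -> None:
        self._tasks: set[asyncio.Task[Any]] = set()

    def add_task(self, func, *args) -> asyncio.Task[Any]:
        # func is an async callable; schedule it and retain the Task object
        task = asyncio.create_task(func(*args))
        self._tasks.add(task)
        # Drop our reference once the task completes
        task.add_done_callback(self._tasks.discard)
        return task

bg_tasks = BackgroundTasks()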

View File

@@ -4,13 +4,11 @@
from __future__ import annotations
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
import secrets
import smtplib
import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.config import settings
from app.log import logger
@@ -18,28 +16,28 @@ from app.log import logger
class EmailService:
"""邮件发送服务"""
def __init__(self):
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
self.smtp_port = getattr(settings, 'smtp_port', 587)
self.smtp_username = getattr(settings, 'smtp_username', '')
self.smtp_password = getattr(settings, 'smtp_password', '')
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
self.from_name = getattr(settings, 'from_name', 'osu! server')
self.smtp_server = getattr(settings, "smtp_server", "localhost")
self.smtp_port = getattr(settings, "smtp_port", 587)
self.smtp_username = getattr(settings, "smtp_username", "")
self.smtp_password = getattr(settings, "smtp_password", "")
self.from_email = getattr(settings, "from_email", "noreply@example.com")
self.from_name = getattr(settings, "from_name", "osu! server")
def generate_verification_code(self) -> str:
"""生成8位验证码"""
# 只使用数字,避免混淆
return ''.join(secrets.choice(string.digits) for _ in range(8))
return "".join(secrets.choice(string.digits) for _ in range(8))
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
"""发送验证邮件"""
try:
msg = MIMEMultipart()
msg['From'] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email
msg['Subject'] = "邮箱验证 - Email Verification"
msg["From"] = f"{self.from_name} <{self.from_email}>"
msg["To"] = email
msg["Subject"] = "邮箱验证 - Email Verification"
# HTML email body
html_content = f"""
<!DOCTYPE html>
@@ -101,15 +99,15 @@ class EmailService:
<h1>osu! 邮箱验证</h1>
<p>Email Verification</p>
</div>
<div class="content">
<h2>你好 {username}</h2>
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
<div class="code">{code}</div>
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
<div class="warning">
<strong>注意:</strong>
<ul>
@@ -118,19 +116,19 @@ class EmailService:
<li>验证码只能使用一次</li>
</ul>
</div>
<p>如果你有任何问题,请联系我们的支持团队。</p>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3>
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div>
<div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p>
@@ -138,26 +136,26 @@ class EmailService:
</div>
</body>
</html>
"""
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
""" # noqa: E501
msg.attach(MIMEText(html_content, "html", "utf-8"))
# Send the message
if not settings.enable_email_sending:
# When sending is disabled, only log; do not actually send
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
return True
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password:
server.starttls()
server.login(self.smtp_username, self.smtp_password)
server.send_message(msg)
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
return True
except Exception as e:
logger.error(f"[Email Verification] Failed to send email: {e}")
return False

View File

@@ -4,40 +4,38 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import secrets
import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_service import email_service
from app.service.email_queue import email_queue # import the email queue
from app.log import logger
from app.config import settings
from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger
from app.service.email_queue import email_queue # import the email queue
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import select
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class EmailVerificationService:
"""邮件验证服务"""
@staticmethod
def generate_verification_code() -> str:
"""生成8位验证码"""
return ''.join(secrets.choice(string.digits) for _ in range(8))
return "".join(secrets.choice(string.digits) for _ in range(8))
@staticmethod
async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool:
"""使用邮件队列发送验证邮件
Args:
email: 接收验证码的邮箱地址
code: 验证码
username: 用户名
user_id: 用户ID
Returns:
是否成功将邮件加入队列
"""
@@ -103,15 +101,15 @@ class EmailVerificationService:
<h1>osu! 邮箱验证</h1>
<p>Email Verification</p>
</div>
<div class="content">
<h2>你好 {username}</h2>
<p>请使用以下验证码验证您的账户:</p>
<div class="code">{code}</div>
<p>验证码将在 <strong>10 分钟内有效</strong>。</p>
<div class="warning">
<p><strong>重要提示:</strong></p>
<ul>
@@ -120,17 +118,17 @@ class EmailVerificationService:
<li>为了账户安全,请勿在其他网站使用相同的密码</li>
</ul>
</div>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3>
<p>Please use the following verification code to verify your account:</p>
<p>This verification code will be valid for <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div>
<div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p>
@@ -138,8 +136,8 @@ class EmailVerificationService:
</div>
</body>
</html>
"""
""" # noqa: E501
# Plain-text fallback body
plain_content = f"""
你好 {username}
@@ -162,34 +160,30 @@ This verification code will be valid for 10 minutes.
© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。
This email was sent automatically, please do not reply.
"""
# Enqueue the email
subject = "邮箱验证 - Email Verification"
metadata = {
"type": "email_verification",
"user_id": user_id,
"code": code
}
metadata = {"type": "email_verification", "user_id": user_id, "code": code}
await email_queue.enqueue_email(
to_email=email,
subject=subject,
content=plain_content,
html_content=html_content,
metadata=metadata
metadata=metadata,
)
return True
except Exception as e:
logger.error(f"[Email Verification] Failed to enqueue email: {e}")
return False
@staticmethod
def generate_session_token() -> str:
"""生成会话令牌"""
return secrets.token_urlsafe(32)
@staticmethod
async def create_verification_record(
db: AsyncSession,
@@ -197,27 +191,27 @@ This email was sent automatically, please do not reply.
user_id: int,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
user_agent: str | None = None,
) -> tuple[EmailVerification, str]:
"""创建邮件验证记录"""
# 检查是否有未过期的验证码
existing_result = await db.exec(
select(EmailVerification).where(
EmailVerification.user_id == user_id,
EmailVerification.is_used == False,
EmailVerification.expires_at > datetime.now(UTC)
col(EmailVerification.is_used).is_(False),
EmailVerification.expires_at > datetime.now(UTC),
)
)
existing = existing_result.first()
if existing:
# If an unexpired code exists, return it directly
return existing, existing.verification_code
# Generate a new verification code
code = EmailVerificationService.generate_verification_code()
# Create the verification record
verification = EmailVerification(
user_id=user_id,
@@ -225,23 +219,23 @@ This email was sent automatically, please do not reply.
verification_code=code,
expires_at=datetime.now(UTC) + timedelta(minutes=10), # expires in 10 minutes
ip_address=ip_address,
user_agent=user_agent
user_agent=user_agent,
)
db.add(verification)
await db.commit()
await db.refresh(verification)
# Store in Redis for fast verification
await redis.setex(
f"email_verification:{user_id}:{code}",
600, # 10-minute expiry
str(verification.id) if verification.id else "0"
str(verification.id) if verification.id else "0",
)
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
return verification, code
@staticmethod
async def send_verification_email(
db: AsyncSession,
@@ -250,7 +244,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
user_agent: str | None = None,
) -> bool:
"""发送验证邮件"""
try:
@@ -258,33 +252,38 @@ This email was sent automatically, please do not reply.
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
return True # report success without running the verification flow
# Create the verification record
verification, code = await EmailVerificationService.create_verification_record(
(
verification,
code,
) = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
# Send the verification email via the email queue
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success:
logger.info(f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})")
logger.info(
f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})"
)
return True
else:
logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})")
return False
except Exception as e:
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False
@staticmethod
async def verify_code(
db: AsyncSession,
redis: Redis,
user_id: int,
code: str,
ip_address: str | None = None
ip_address: str | None = None,
) -> tuple[bool, str]:
"""验证验证码"""
try:
@@ -294,46 +293,46 @@ This email was sent automatically, please do not reply.
# Still mark the login session as verified
await LoginSessionService.mark_session_verified(db, user_id)
return True, "验证成功(邮件验证功能已禁用)"
# Check Redis first
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
if not verification_id:
return False, "验证码无效或已过期"
# Fetch the verification record from the database
result = await db.exec(
select(EmailVerification).where(
EmailVerification.id == int(verification_id),
EmailVerification.user_id == user_id,
EmailVerification.verification_code == code,
EmailVerification.is_used == False,
EmailVerification.expires_at > datetime.now(UTC)
col(EmailVerification.is_used).is_(False),
EmailVerification.expires_at > datetime.now(UTC),
)
)
verification = result.first()
if not verification:
return False, "验证码无效或已过期"
# Mark it as used
verification.is_used = True
verification.used_at = datetime.now(UTC)
# Also update the corresponding login session state
await LoginSessionService.mark_session_verified(db, user_id)
await db.commit()
# Delete the Redis record
await redis.delete(f"email_verification:{user_id}:{code}")
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功"
except Exception as e:
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
return False, "验证过程中发生错误"
@staticmethod
async def resend_verification_code(
db: AsyncSession,
@@ -342,7 +341,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
user_agent: str | None = None,
) -> tuple[bool, str]:
"""重新发送验证码"""
try:
@@ -350,25 +349,25 @@ This email was sent automatically, please do not reply.
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
return True, "验证码已发送(邮件验证功能已禁用)"
# Check the resend rate limit: at most one send per 60 seconds
rate_limit_key = f"email_verification_rate_limit:{user_id}"
if await redis.get(rate_limit_key):
return False, "请等待60秒后再重新发送"
# Set the rate limit
await redis.setex(rate_limit_key, 60, "1")
# Generate a new verification code
success = await EmailVerificationService.send_verification_email(
db, redis, user_id, username, email, ip_address, user_agent
)
if success:
return True, "验证码已重新发送"
else:
return False, "重新发送失败,请稍后再试"
except Exception as e:
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
return False, "重新发送过程中发生错误"
@@ -376,7 +375,7 @@ This email was sent automatically, please do not reply.
class LoginSessionService:
"""登录会话服务"""
@staticmethod
async def create_session(
db: AsyncSession,
@@ -385,47 +384,40 @@ class LoginSessionService:
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False
is_new_location: bool = False,
) -> LoginSession:
"""创建登录会话"""
from app.utils import simplify_user_agent
session_token = EmailVerificationService.generate_session_token()
# Simplify the User-Agent string
simplified_user_agent = simplify_user_agent(user_agent, max_length=250)
session = LoginSession(
user_id=user_id,
session_token=session_token,
ip_address=ip_address,
user_agent=simplified_user_agent,
user_agent=None,
country_code=country_code,
is_new_location=is_new_location,
expires_at=datetime.now(UTC) + timedelta(hours=24), # expires in 24 hours
is_verified=not is_new_location # a new location requires verification
is_verified=not is_new_location, # a new location requires verification
)
db.add(session)
await db.commit()
await db.refresh(session)
# Store in Redis
await redis.setex(
f"login_session:{session_token}",
86400, # 24 hours
user_id
user_id,
)
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
return session
@staticmethod
async def verify_session(
db: AsyncSession,
redis: Redis,
session_token: str,
verification_code: str
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
) -> tuple[bool, str]:
"""验证会话(通过邮件验证码)"""
try:
@@ -433,98 +425,89 @@ class LoginSessionService:
user_id = await redis.get(f"login_session:{session_token}")
if not user_id:
return False, "会话无效或已过期"
user_id = int(user_id)
# Validate the email verification code
success, message = await EmailVerificationService.verify_code(
db, redis, user_id, verification_code
)
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
if not success:
return False, message
# Update the session state
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_token,
LoginSession.user_id == user_id,
LoginSession.is_verified == False
col(LoginSession.is_verified).is_(False),
)
)
session = result.first()
if session:
session.is_verified = True
session.verified_at = datetime.now(UTC)
await db.commit()
logger.info(f"[Login Session] User {user_id} session verification successful")
return True, "会话验证成功"
except Exception as e:
logger.error(f"[Login Session] Exception during session verification: {e}")
return False, "验证过程中发生错误"
@staticmethod
async def check_new_location(
db: AsyncSession,
user_id: int,
ip_address: str,
country_code: str | None = None
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
) -> bool:
"""检查是否为新位置登录"""
try:
# Look for logins from the same IP or the same country within the past 30 days
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) |
(LoginSession.country_code == country_code)
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
)
)
existing_sessions = result.all()
# Any matching history means this is not a new location
return len(existing_sessions) == 0
except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}")
# Default to "new location" on error (the safer choice)
return True
@staticmethod
async def mark_session_verified(
db: AsyncSession,
user_id: int
) -> bool:
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
"""标记用户的未验证会话为已验证"""
try:
# Find all of the user's unverified, unexpired sessions
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.is_verified == False,
LoginSession.expires_at > datetime.now(UTC)
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > datetime.now(UTC),
)
)
sessions = result.all()
# Mark every session as verified
for session in sessions:
session.is_verified = True
session.verified_at = datetime.now(UTC)
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
return len(sessions) > 0
except Exception as e:
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
return False

View File

@@ -117,14 +117,10 @@ class EnhancedIntervalStatsManager:
@staticmethod
async def get_current_interval_info() -> IntervalInfo:
"""获取当前区间信息"""
start_time, end_time = (
EnhancedIntervalStatsManager.get_current_interval_boundaries()
)
start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
return IntervalInfo(
start_time=start_time, end_time=end_time, interval_key=interval_key
)
return IntervalInfo(start_time=start_time, end_time=end_time, interval_key=interval_key)
@staticmethod
async def initialize_current_interval() -> None:
@@ -133,9 +129,7 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
# Store the current interval's info
await _redis_exec(
@@ -147,9 +141,7 @@ class EnhancedIntervalStatsManager:
# Initialize the interval's user sets (if they do not exist)
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
playing_key = (
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
# Set a 35-minute expiry
await redis_async.expire(online_key, 35 * 60)
@@ -179,7 +171,8 @@ class EnhancedIntervalStatsManager:
await EnhancedIntervalStatsManager._ensure_24h_history_exists()
logger.info(
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}"
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')}"
f" - {current_interval.end_time.strftime('%H:%M')}"
)
except Exception as e:
@@ -193,42 +186,32 @@ class EnhancedIntervalStatsManager:
try:
# Check how many history data points already exist
history_length = await _redis_exec(
redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
)
history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
if history_length < 48: # fewer than 48 data points (24 hours x 2)
logger.info(
f"History has only {history_length} points, filling with zeros for 24h"
)
logger.info(f"History has only {history_length} points, filling with zeros for 24h")
# Compute how many data points need to be filled
needed_points = 48 - history_length
# Walk back from the current time, filling the missing points with zeros
current_time = datetime.utcnow()
current_interval_start, _ = (
EnhancedIntervalStatsManager.get_current_interval_boundaries()
)
current_time = datetime.utcnow() # noqa: F841
current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
# Walk back from the current interval, creating history points aligned to 30-minute boundaries
fill_points = []
for i in range(needed_points):
# Step back 30 minutes each time to keep the timestamps aligned
point_time = current_interval_start - timedelta(
minutes=30 * (i + 1)
)
point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
# Align the time to a 30-minute boundary
aligned_minute = (point_time.minute // 30) * 30
point_time = point_time.replace(
minute=aligned_minute, second=0, microsecond=0
)
point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
history_point = {
"timestamp": point_time.isoformat(),
"online_count": 0,
"playing_count": 0
"playing_count": 0,
}
fill_points.append(json.dumps(history_point))
@@ -238,9 +221,7 @@ class EnhancedIntervalStatsManager:
temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
if history_length > 0:
# Copy existing data to a temporary key
existing_data = await _redis_exec(
redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
)
existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
if existing_data:
for data in existing_data:
await _redis_exec(redis_sync.rpush, temp_key, data)
@@ -250,19 +231,13 @@ class EnhancedIntervalStatsManager:
# Add the fill data first (the oldest)
for point in reversed(fill_points): # add in reverse so the oldest ends up last
await _redis_exec(
redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
)
await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
# Then add the original (newer) data back
if history_length > 0:
existing_data = await _redis_exec(
redis_sync.lrange, temp_key, 0, -1
)
existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
for data in existing_data:
await _redis_exec(
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
)
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
# Clean up the temporary key
await redis_async.delete(temp_key)
@@ -273,9 +248,7 @@ class EnhancedIntervalStatsManager:
# Set the expiry
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info(
f"Filled {len(fill_points)} historical data points with zeros"
)
logger.info(f"Filled {len(fill_points)} historical data points with zeros")
except Exception as e:
logger.error(f"Error ensuring 24h history exists: {e}")
@@ -287,9 +260,7 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
# Add to the interval's online-user set
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
@@ -298,9 +269,7 @@ class EnhancedIntervalStatsManager:
# If the user is playing, also add them to the playing-user set
if is_playing:
playing_key = (
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
await redis_async.expire(playing_key, 35 * 60)
@@ -308,7 +277,8 @@ class EnhancedIntervalStatsManager:
await EnhancedIntervalStatsManager._update_interval_stats()
logger.debug(
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}"
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}"
f"-{current_interval.end_time.strftime('%H:%M')}"
)
except Exception as e:
@@ -321,15 +291,11 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
# Get the interval's unique user counts
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
playing_key = (
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
unique_online = await _redis_exec(redis_sync.scard, online_key)
unique_playing = await _redis_exec(redis_sync.scard, playing_key)
@@ -339,16 +305,12 @@ class EnhancedIntervalStatsManager:
current_playing = await _get_playing_users_count(redis_async)
# Fetch the existing stats
existing_data = await _redis_exec(
redis_sync.get, current_interval.interval_key
)
existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
if existing_data:
stats = IntervalStats.from_dict(json.loads(existing_data))
# Update the peaks
stats.peak_online_count = max(stats.peak_online_count, current_online)
stats.peak_playing_count = max(
stats.peak_playing_count, current_playing
)
stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
stats.total_samples += 1
else:
# Create a new stats record
@@ -377,7 +339,8 @@ class EnhancedIntervalStatsManager:
await redis_async.expire(current_interval.interval_key, 35 * 60)
logger.debug(
f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
f"Updated interval stats: online={unique_online}, playing={unique_playing}, "
f"peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
)
except Exception as e:
@@ -395,21 +358,21 @@ class EnhancedIntervalStatsManager:
# The previous interval starts 30 minutes before the current one
previous_start = current_start - timedelta(minutes=30)
previous_end = current_start # the previous interval ends where the current one begins
interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start)
previous_interval = IntervalInfo(
start_time=previous_start,
end_time=previous_end,
interval_key=interval_key
interval_key=interval_key,
)
# Fetch the final stats
stats_data = await _redis_exec(
redis_sync.get, previous_interval.interval_key
)
stats_data = await _redis_exec(redis_sync.get, previous_interval.interval_key)
if not stats_data:
logger.warning(f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}")
logger.warning(
f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}"
)
return None
stats = IntervalStats.from_dict(json.loads(stats_data))
@@ -418,13 +381,11 @@ class EnhancedIntervalStatsManager:
history_point = {
"timestamp": previous_interval.start_time.isoformat(),
"online_count": stats.unique_online_users,
"playing_count": stats.unique_playing_users
"playing_count": stats.unique_playing_users,
}
# Append to the history list
await _redis_exec(
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
)
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
# Keep only 48 data points (24 hours, one per 30 minutes)
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# Set a 26-hour expiry to leave enough buffer
@@ -452,12 +413,8 @@ class EnhancedIntervalStatsManager:
redis_sync = get_redis_message()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
stats_data = await _redis_exec(
redis_sync.get, current_interval.interval_key
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
if stats_data:
return IntervalStats.from_dict(json.loads(stats_data))
@@ -506,8 +463,6 @@ class EnhancedIntervalStatsManager:
# Convenience function that replaces the existing stats-update helpers
async def update_user_activity_in_interval(
user_id: int, is_playing: bool = False
) -> None:
async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None:
"""用户活动时更新区间统计(在登录、开始游玩等时调用)"""
await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)
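The half-hour alignment used when back-filling history is integer floor arithmetic on the minute field: (minute // 30) * 30 maps 17:43 to 17:30 and 17:12 to 17:00, so every synthetic point lands exactly on a 30-minute boundary. A quick check:

from datetime import datetime

t = datetime(2025, 8, 22, 17, 43, 21)
aligned = t.replace(minute=(t.minute // 30) * 30, second=0, microsecond=0)
print(aligned)  # 2025-08-22 17:30:00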

View File

@@ -11,12 +11,8 @@ def load_achievements() -> Medals:
for module in ACHIEVEMENTS_DIR.iterdir():
if module.is_file() and module.suffix == ".py":
module_name = module.stem
module_achievements = importlib.import_module(
f"app.achievements.{module_name}"
)
module_achievements = importlib.import_module(f"app.achievements.{module_name}")
medals = getattr(module_achievements, "MEDALS", {})
MEDALS.update(medals)
logger.success(
f"Successfully loaded {len(medals)} achievements from {module_name}.py"
)
logger.success(f"Successfully loaded {len(medals)} achievements from {module_name}.py")
return MEDALS

View File

@@ -47,6 +47,7 @@ class LoginLogService:
# Get and simplify the User-Agent
from app.utils import simplify_user_agent
raw_user_agent = request.headers.get("User-Agent", "")
user_agent = simplify_user_agent(raw_user_agent, max_length=500)
@@ -67,9 +68,7 @@ class LoginLogService:
# Run the GeoIP lookup in a background thread to avoid blocking
loop = asyncio.get_event_loop()
geo_info = await loop.run_in_executor(
None, lambda: geoip.lookup(ip_address)
)
geo_info = await loop.run_in_executor(None, lambda: geoip.lookup(ip_address))
if geo_info:
login_log.country_code = geo_info.get("country_iso", "")
@@ -89,10 +88,7 @@ class LoginLogService:
login_log.organization = geo_info.get("organization", "")
logger.debug(
f"GeoIP lookup for {ip_address}: "
f"{geo_info.get('country_name', 'Unknown')}"
)
logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}")
else:
logger.warning(f"GeoIP lookup failed for {ip_address}")
@@ -104,9 +100,7 @@ class LoginLogService:
await db.commit()
await db.refresh(login_log)
logger.info(
f"Login recorded for user {user_id} from {ip_address} ({login_method})"
)
logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})")
return login_log
@staticmethod
@@ -137,9 +131,7 @@ class LoginLogService:
request=request,
login_success=False,
login_method=login_method,
notes=f"Failed login attempt: {attempted_username}"
if attempted_username
else "Failed login attempt",
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
)
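The run_in_executor wrapper above exists because the GeoIP lookup is synchronous; awaiting it directly would block the event loop, so it runs on the default thread pool instead. A self-contained sketch of the pattern (lookup here is a stand-in for whatever blocking resolver the repo wraps):

import asyncio

def lookup(ip: str) -> dict:
    # Stand-in for a blocking GeoIP database query
    return {"country_iso": "US", "country_name": "United States"}

async def lookup_async(ip: str) -> dict:
    loop = asyncio.get_event_loop()
    # Offload the blocking call to the default thread-pool executor
    return await loop.run_in_executor(None, lambda: lookup(ip))

print(asyncio.run(lookup_async("8.8.8.8")))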

View File

@@ -13,6 +13,7 @@ import uuid
from app.database.chat import ChatMessage, MessageType
from app.dependencies.database import get_redis, with_db
from app.log import logger
from app.utils import bg_tasks
class MessageQueue:
@@ -34,7 +35,7 @@ class MessageQueue:
"""启动消息处理任务"""
if not self._processing:
self._processing = True
asyncio.create_task(self._process_message_queue())
bg_tasks.add_task(self._process_message_queue)
logger.info("Message queue processing started")
async def stop_processing(self):
@@ -59,12 +60,8 @@ class MessageQueue:
message_data["status"] = "pending" # pending, processing, completed, failed
# Store the message in Redis
await self._run_in_executor(
lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)
)
await self._run_in_executor(
self.redis.expire, f"msg:{temp_uuid}", 3600
) # 1小时过期
await self._run_in_executor(lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data))
await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1-hour expiry
# Add to the processing queue
await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid)
@@ -74,17 +71,13 @@ class MessageQueue:
async def get_message_status(self, temp_uuid: str) -> dict | None:
"""获取消息状态"""
message_data = await self._run_in_executor(
self.redis.hgetall, f"msg:{temp_uuid}"
)
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
if not message_data:
return None
return message_data
async def get_cached_messages(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""
Get cached messages from Redis
@@ -97,15 +90,11 @@ class MessageQueue:
the list of messages
"""
# Fetch the channel's recent message UUID list from Redis
message_uuids = await self._run_in_executor(
self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1
)
message_uuids = await self._run_in_executor(self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1)
messages = []
for uuid_str in message_uuids:
message_data = await self._run_in_executor(
self.redis.hgetall, f"msg:{uuid_str}"
)
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}")
if message_data:
# Check the "since" condition
if since > 0 and "message_id" in message_data:
@@ -116,22 +105,14 @@ class MessageQueue:
return messages[::-1] # return in chronological order
async def cache_channel_message(
self, channel_id: int, temp_uuid: str, max_cache: int = 100
):
async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100):
"""将消息 UUID 缓存到频道消息列表"""
# Prepend to the channel's message list
await self._run_in_executor(
self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid
)
await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid)
# Bound the cache size
await self._run_in_executor(
self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1
)
await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1)
# Set a 24-hour expiry
await self._run_in_executor(
self.redis.expire, f"channel:{channel_id}:messages", 86400
)
await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400)
async def _process_message_queue(self):
"""异步处理消息队列,批量写入数据库"""
@@ -140,9 +121,7 @@ class MessageQueue:
# Fetch messages in batches
message_uuids = []
for _ in range(self._batch_size):
result = await self._run_in_executor(
lambda: self.redis.brpop(["message_queue"], timeout=1)
)
result = await self._run_in_executor(lambda: self.redis.brpop(["message_queue"], timeout=1))
if result:
message_uuids.append(result[1])
else:
@@ -166,16 +145,12 @@ class MessageQueue:
for temp_uuid in message_uuids:
try:
# Fetch the message data
message_data = await self._run_in_executor(
self.redis.hgetall, f"msg:{temp_uuid}"
)
message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}")
if not message_data:
continue
# Mark the status as processing
await self._run_in_executor(
self.redis.hset, f"msg:{temp_uuid}", "status", "processing"
)
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing")
# Build the database message object
msg = ChatMessage(
@@ -190,9 +165,7 @@ class MessageQueue:
except Exception as e:
logger.error(f"Error preparing message {temp_uuid}: {e}")
await self._run_in_executor(
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
)
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
if messages_to_insert:
try:
@@ -211,16 +184,12 @@ class MessageQueue:
mapping={
"status": "completed",
"message_id": str(msg.message_id),
"created_at": msg.timestamp.isoformat()
if msg.timestamp
else "",
"created_at": msg.timestamp.isoformat() if msg.timestamp else "",
},
)
)
logger.info(
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}"
)
logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}")
except Exception as e:
logger.error(f"Error inserting messages to database: {e}")
@@ -228,9 +197,7 @@ class MessageQueue:
# 标记所有消息为失败
for _, temp_uuid in messages_to_insert:
await self._run_in_executor(
self.redis.hset, f"msg:{temp_uuid}", "status", "failed"
)
await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed")
# 全局消息队列实例
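Taken together, the hunks above implement a write-behind cache: every message lands in a Redis hash first, its UUID goes onto a list, and the background worker batch-pops that list into the database. A minimal standalone sketch of the same pattern (key names mirror the ones above; synchronous redis-py client for brevity):

import json
import uuid

import redis

r = redis.Redis()

def enqueue(payload: dict) -> str:
    temp_id = str(uuid.uuid4())
    r.hset(f"msg:{temp_id}", mapping={"body": json.dumps(payload), "status": "pending"})
    r.expire(f"msg:{temp_id}", 3600)        # unprocessed entries expire after 1 hour
    r.lpush("message_queue", temp_id)       # hand off to the background worker
    return temp_id

def drain(batch_size: int = 20) -> list[dict]:
    batch = []
    for _ in range(batch_size):
        item = r.brpop(["message_queue"], timeout=1)
        if item is None:
            break                           # queue drained; flush what we have
        body = r.hget(f"msg:{item[1].decode()}", "body")
        if body:
            batch.append(json.loads(body))
    return batch                            # caller bulk-inserts, then marks "completed"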

View File

@@ -33,36 +33,22 @@ class MessageQueueProcessor:
"""将消息缓存到 Redis"""
try:
# 存储消息数据
await self._redis_exec(
self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data
)
await self._redis_exec(
self.redis_message.expire, f"msg:{temp_uuid}", 3600
) # 1小时过期
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data)
await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期
# 加入频道消息列表
await self._redis_exec(
self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid
)
await self._redis_exec(
self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99
) # 保持最新100条
await self._redis_exec(
self.redis_message.expire, f"channel:{channel_id}:messages", 86400
) # 24小时过期
await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid)
await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条
await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期
# 加入异步处理队列
await self._redis_exec(
self.redis_message.lpush, "message_write_queue", temp_uuid
)
await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid)
logger.info(f"Message cached to Redis: {temp_uuid}")
except Exception as e:
logger.error(f"Failed to cache message to Redis: {e}")
async def get_cached_messages(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""从 Redis 获取缓存的消息"""
try:
message_uuids = await self._redis_exec(
@@ -78,15 +64,11 @@ class MessageQueueProcessor:
if isinstance(temp_uuid, bytes):
temp_uuid = temp_uuid.decode("utf-8")
raw_data = await self._redis_exec(
self.redis_message.hgetall, f"msg:{temp_uuid}"
)
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
if raw_data:
# 解码 Redis 返回的字节数据
message_data = {
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
"utf-8"
)
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
if isinstance(v, bytes)
else v
for k, v in raw_data.items()
@@ -103,9 +85,7 @@ class MessageQueueProcessor:
logger.error(f"Failed to get cached messages: {e}")
return []
async def update_message_status(
self, temp_uuid: str, status: str, message_id: int | None = None
):
async def update_message_status(self, temp_uuid: str, status: str, message_id: int | None = None):
"""更新消息状态"""
try:
update_data = {"status": status}
@@ -113,26 +93,20 @@ class MessageQueueProcessor:
update_data["message_id"] = str(message_id)
update_data["db_timestamp"] = datetime.now().isoformat()
await self._redis_exec(
self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data
)
await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data)
except Exception as e:
logger.error(f"Failed to update message status: {e}")
async def get_message_status(self, temp_uuid: str) -> dict | None:
"""获取消息状态"""
try:
raw_data = await self._redis_exec(
self.redis_message.hgetall, f"msg:{temp_uuid}"
)
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
if not raw_data:
return None
# 解码 Redis 返回的字节数据
return {
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
if isinstance(v, bytes)
else v
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
for k, v in raw_data.items()
}
except Exception as e:
@@ -148,9 +122,7 @@ class MessageQueueProcessor:
# 批量获取消息
message_uuids = []
for _ in range(20): # 批量处理20条消息
result = await self._redis_exec(
self.redis_message.brpop, ["message_write_queue"], timeout=1
)
result = await self._redis_exec(self.redis_message.brpop, ["message_write_queue"], timeout=1)
if result:
# result是 (queue_name, value) 的元组,需要解码
uuid_value = result[1]
@@ -179,17 +151,13 @@ class MessageQueueProcessor:
for temp_uuid in message_uuids:
try:
# 获取消息数据并解码
raw_data = await self._redis_exec(
self.redis_message.hgetall, f"msg:{temp_uuid}"
)
raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}")
if not raw_data:
continue
# 解码 Redis 返回的字节数据
message_data = {
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode(
"utf-8"
)
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8")
if isinstance(v, bytes)
else v
for k, v in raw_data.items()
@@ -215,10 +183,7 @@ class MessageQueueProcessor:
await session.refresh(msg)
# 更新成功状态包含临时消息ID映射
assert msg.message_id is not None
await self.update_message_status(
temp_uuid, "completed", msg.message_id
)
await self.update_message_status(temp_uuid, "completed", msg.message_id)
# 如果有临时消息ID存储映射关系并通知客户端更新
if message_data.get("temp_message_id"):
@@ -232,12 +197,11 @@ class MessageQueueProcessor:
# 发送消息ID更新通知到频道
channel_id = int(message_data["channel_id"])
await self._notify_message_update(
channel_id, temp_msg_id, msg.message_id, message_data
)
await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data)
logger.info(
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}"
f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, "
f"temp_id: {message_data.get('temp_message_id')}"
)
except Exception as e:
@@ -272,9 +236,7 @@ class MessageQueueProcessor:
json.dumps(update_event),
)
logger.info(
f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}"
)
logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}")
except Exception as e:
logger.error(f"Failed to notify message update: {e}")
@@ -320,9 +282,7 @@ async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid:
await message_queue_processor.cache_message(channel_id, message_data, temp_uuid)
async def get_cached_messages(
channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""从 Redis 获取缓存的消息 - 便捷接口"""
return await message_queue_processor.get_cached_messages(channel_id, limit, since)
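The bytes-decoding comprehension recurs after every hgetall in this class; a tiny helper would keep it in one place. A sketch (decode_hash is a hypothetical name, not part of the module):

def decode_hash(raw: dict) -> dict:
    """Normalize a redis-py HGETALL result whose keys/values may be bytes."""
    def _d(x):
        return x.decode("utf-8") if isinstance(x, bytes) else x

    return {_d(k): _d(v) for k, v in raw.items()}

# usage: message_data = decode_hash(await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}"))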

View File

@@ -4,10 +4,10 @@
此模块提供在游玩状态下维护用户在线状态的功能,
解决游玩时显示离线的问题。
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
from app.dependencies.database import get_redis
from app.log import logger
@@ -17,32 +17,32 @@ from app.router.v2.stats import REDIS_PLAYING_USERS_KEY, _redis_exec, get_redis_
async def maintain_playing_users_online_status():
"""
维护正在游玩用户的在线状态
定期刷新正在游玩用户的metadata在线标记
确保他们在游玩过程中显示为在线状态。
"""
redis_sync = get_redis_message()
redis_async = get_redis()
try:
# 获取所有正在游玩的用户
playing_users = await _redis_exec(redis_sync.smembers, REDIS_PLAYING_USERS_KEY)
if not playing_users:
return
logger.debug(f"Maintaining online status for {len(playing_users)} playing users")
# 为每个游玩用户刷新metadata在线标记
for user_id in playing_users:
user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
metadata_key = f"metadata:online:{user_id_str}"
# 设置或刷新metadata在线标记过期时间为1小时
await redis_async.set(metadata_key, "playing", ex=3600)
logger.debug(f"Updated metadata online status for {len(playing_users)} playing users")
except Exception as e:
logger.error(f"Error maintaining playing users online status: {e}")
@@ -50,11 +50,11 @@ async def maintain_playing_users_online_status():
async def start_online_status_maintenance_task():
"""
启动在线状态维护任务
每5分钟运行一次维护任务确保游玩用户保持在线状态
"""
logger.info("Starting online status maintenance task")
while True:
try:
await maintain_playing_users_online_status()
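The hunk ends before the loop's sleep/except arms, but the docstring pins down the intended shape: a periodic task that survives individual failures and reruns every five minutes. A self-contained sketch of that loop:

import asyncio
import logging

logger = logging.getLogger(__name__)

async def run_periodically(task, interval_seconds: int = 300) -> None:
    """Run `task` forever; one failed iteration must not kill the loop."""
    while True:
        try:
            await task()
        except Exception:
            logger.exception("Periodic maintenance task failed")
        await asyncio.sleep(interval_seconds)

# usage: asyncio.create_task(run_periodically(maintain_playing_users_online_status))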

View File

@@ -3,9 +3,9 @@
此模块负责统一管理用户的在线状态确保用户在连接WebSocket后立即显示为在线。
"""
from __future__ import annotations
import asyncio
from datetime import datetime
from app.dependencies.database import get_redis
@@ -15,92 +15,93 @@ from app.router.v2.stats import add_online_user
class OnlineStatusManager:
"""在线状态管理器"""
@staticmethod
async def set_user_online(user_id: int, hub_type: str = "general") -> None:
"""
设置用户为在线状态
Args:
user_id: 用户ID
hub_type: Hub类型 (metadata, spectator, multiplayer等)
"""
try:
redis = get_redis()
# 1. 添加到在线用户集合
await add_online_user(user_id)
# 2. 设置metadata在线标记这是is_online检查的关键
metadata_key = f"metadata:online:{user_id}"
await redis.set(metadata_key, hub_type, ex=7200) # 2小时过期
# 3. 设置最后活跃时间戳
last_seen_key = f"user:last_seen:{user_id}"
await redis.set(last_seen_key, int(datetime.utcnow().timestamp()), ex=7200)
logger.debug(f"[OnlineStatusManager] User {user_id} set online via {hub_type}")
except Exception as e:
logger.error(f"[OnlineStatusManager] Error setting user {user_id} online: {e}")
@staticmethod
async def refresh_user_online_status(user_id: int, hub_type: str = "active") -> None:
"""
刷新用户的在线状态
Args:
user_id: 用户ID
hub_type: 当前活动类型
"""
try:
redis = get_redis()
# 刷新metadata在线标记
metadata_key = f"metadata:online:{user_id}"
await redis.set(metadata_key, hub_type, ex=7200)
# 刷新最后活跃时间
last_seen_key = f"user:last_seen:{user_id}"
await redis.set(last_seen_key, int(datetime.utcnow().timestamp()), ex=7200)
logger.debug(f"[OnlineStatusManager] Refreshed online status for user {user_id}")
except Exception as e:
logger.error(f"[OnlineStatusManager] Error refreshing user {user_id} status: {e}")
@staticmethod
async def set_user_offline(user_id: int) -> None:
"""
设置用户为离线状态
Args:
user_id: 用户ID
"""
try:
redis = get_redis()
# 删除metadata在线标记
metadata_key = f"metadata:online:{user_id}"
await redis.delete(metadata_key)
# 从在线用户集合中移除
from app.router.v2.stats import remove_online_user
await remove_online_user(user_id)
logger.debug(f"[OnlineStatusManager] User {user_id} set offline")
except Exception as e:
logger.error(f"[OnlineStatusManager] Error setting user {user_id} offline: {e}")
@staticmethod
async def is_user_online(user_id: int) -> bool:
"""
检查用户是否在线
Args:
user_id: 用户ID
Returns:
bool: 用户是否在线
"""
@@ -112,19 +113,19 @@ class OnlineStatusManager:
except Exception as e:
logger.error(f"[OnlineStatusManager] Error checking user {user_id} online status: {e}")
return False
@staticmethod
async def get_online_users_count() -> int:
"""
获取在线用户数量
Returns:
int: 在线用户数量
"""
try:
from app.router.v2.stats import _get_online_users_count
from app.dependencies.database import get_redis
from app.router.v2.stats import _get_online_users_count
redis = get_redis()
return await _get_online_users_count(redis)
except Exception as e:
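For orientation, this is how a hub handler would be expected to drive the manager over a connection's lifetime (an illustrative sketch; the handler names are hypothetical, not part of this module):

async def on_hub_connect(user_id: int, hub_type: str = "metadata") -> None:
    await OnlineStatusManager.set_user_online(user_id, hub_type)

async def on_hub_activity(user_id: int) -> None:
    # Refresh the 2-hour TTL so long sessions never flip to offline mid-play.
    await OnlineStatusManager.refresh_user_online_status(user_id)

async def on_hub_disconnect(user_id: int) -> None:
    await OnlineStatusManager.set_user_offline(user_id)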

View File

@@ -50,7 +50,6 @@ class OptimizedMessageService:
Returns:
消息响应对象
"""
assert sender.id is not None
# 准备消息数据
message_data = {
@@ -97,9 +96,7 @@ class OptimizedMessageService:
logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}")
return temp_response
async def get_cached_messages(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[dict]:
async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]:
"""
获取缓存的消息
@@ -125,9 +122,7 @@ class OptimizedMessageService:
"""
return await self.message_queue.get_message_status(temp_uuid)
async def wait_for_message_persisted(
self, temp_uuid: str, timeout: int = 30
) -> dict | None:
async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> dict | None: # noqa: ASYNC109
"""
等待消息持久化到数据库
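The body of wait_for_message_persisted is cut off here; a plausible shape, given get_message_status above, is to poll until the status turns terminal or the timeout lapses (a hedged sketch, not the actual implementation):

import asyncio

async def wait_for_status(get_status, temp_uuid: str, timeout: int = 30) -> dict | None:
    """Poll a queued message until it reaches a terminal status or we time out."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        status = await get_status(temp_uuid)
        if status and status.get("status") in ("completed", "failed"):
            return status
        await asyncio.sleep(0.2)  # short poll interval
    return None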

View File

@@ -4,74 +4,67 @@
from __future__ import annotations
from datetime import UTC, datetime
import json
import secrets
import string
from datetime import datetime, UTC, timedelta
from typing import Optional, Tuple
import json
from app.config import settings
from app.auth import get_password_hash, invalidate_user_tokens
from app.database import User
from app.dependencies.database import with_db
from app.service.email_service import EmailService
from app.service.email_queue import email_queue # 导入邮件队列
from app.log import logger
from app.auth import get_password_hash, invalidate_user_tokens
from app.service.email_queue import email_queue # 导入邮件队列
from app.service.email_service import EmailService
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from redis.asyncio import Redis
from sqlmodel import select
class PasswordResetService:
"""密码重置服务 - 使用Redis管理验证码"""
# Redis键前缀
RESET_CODE_PREFIX = "password_reset:code:" # 存储验证码
RESET_RATE_LIMIT_PREFIX = "password_reset:rate_limit:" # 限制请求频率
def __init__(self):
self.email_service = EmailService()
def generate_reset_code(self) -> str:
"""生成8位重置验证码"""
return ''.join(secrets.choice(string.digits) for _ in range(8))
return "".join(secrets.choice(string.digits) for _ in range(8))
def _get_reset_code_key(self, email: str) -> str:
"""获取验证码Redis键"""
return f"{self.RESET_CODE_PREFIX}{email.lower()}"
def _get_rate_limit_key(self, email: str) -> str:
"""获取频率限制Redis键"""
return f"{self.RESET_RATE_LIMIT_PREFIX}{email.lower()}"
async def request_password_reset(
self,
email: str,
ip_address: str,
user_agent: str,
redis: Redis
) -> Tuple[bool, str]:
self, email: str, ip_address: str, user_agent: str, redis: Redis
) -> tuple[bool, str]:
"""
请求密码重置
Args:
email: 邮箱地址
ip_address: 请求IP
user_agent: 用户代理
redis: Redis连接
Returns:
Tuple[success, message]
"""
email = email.lower().strip()
async with with_db() as session:
# 查找用户
user_query = select(User).where(User.email == email)
user_result = await session.exec(user_query)
user = user_result.first()
if not user:
# 为了安全考虑,不告诉用户邮箱不存在,但仍然要检查频率限制
rate_limit_key = self._get_rate_limit_key(email)
@@ -80,15 +73,15 @@ class PasswordResetService:
# 设置一个假的频率限制,防止恶意用户探测邮箱
await redis.setex(rate_limit_key, 60, "1")
return True, "如果该邮箱地址存在,您将收到密码重置邮件"
# 检查频率限制
rate_limit_key = self._get_rate_limit_key(email)
if await redis.get(rate_limit_key):
return False, "请求过于频繁,请稍后再试"
# 生成重置验证码
reset_code = self.generate_reset_code()
# 存储验证码信息到Redis
reset_code_key = self._get_reset_code_key(email)
reset_data = {
@@ -98,22 +91,18 @@ class PasswordResetService:
"created_at": datetime.now(UTC).isoformat(),
"ip_address": ip_address,
"user_agent": user_agent,
"used": False
"used": False,
}
try:
# 先设置频率限制
await redis.setex(rate_limit_key, 60, "1")
# 存储验证码10分钟过期
await redis.setex(reset_code_key, 600, json.dumps(reset_data))
# 发送重置邮件
email_sent = await self.send_password_reset_email(
email=email,
code=reset_code,
username=user.username
)
email_sent = await self.send_password_reset_email(email=email, code=reset_code, username=user.username)
if email_sent:
logger.info(f"[Password Reset] Sent reset code to user {user.id} ({email})")
return True, "密码重置邮件已发送,请查收邮箱"
@@ -123,17 +112,17 @@ class PasswordResetService:
await redis.delete(rate_limit_key)
logger.warning(f"[Password Reset] Email sending failed, cleaned up Redis data for {email}")
return False, "邮件发送失败,请稍后重试"
except Exception as e:
except Exception:
# Redis操作失败清理可能的部分数据
try:
await redis.delete(reset_code_key)
await redis.delete(rate_limit_key)
except:
except Exception:
pass
logger.error(f"[Password Reset] Redis operation failed: {e}")
logger.exception("[Password Reset] Redis operation failed")
return False, "服务暂时不可用,请稍后重试"
async def send_password_reset_email(self, email: str, code: str, username: str) -> bool:
"""发送密码重置邮件(使用邮件队列)"""
try:
@@ -206,15 +195,15 @@ class PasswordResetService:
<h1>osu! 密码重置</h1>
<p>Password Reset Request</p>
</div>
<div class="content">
<h2>你好 {username}</h2>
<p>我们收到了您的密码重置请求。如果这是您本人操作,请使用以下验证码重置密码:</p>
<div class="code">{code}</div>
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
<div class="danger">
<strong>⚠️ 安全提醒:</strong>
<ul>
@@ -224,19 +213,19 @@ class PasswordResetService:
<li>建议设置一个强密码以保护您的账户安全</li>
</ul>
</div>
<p>如果您有任何问题,请联系我们的支持团队。</p>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3>
<p>We received a request to reset your password. If this was you, please use the following verification code to reset your password:</p>
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
<p><strong>Security Notice:</strong> Do not share this verification code with anyone. If you did not request a password reset, please ignore this email.</p>
</div>
<div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p>
@@ -244,8 +233,8 @@ class PasswordResetService:
</div>
</body>
</html>
"""
""" # noqa: E501
# 纯文本内容(作为备用)
plain_content = f"""
你好 {username}
@@ -270,120 +259,123 @@ class PasswordResetService:
# 添加邮件到队列
subject = "密码重置 - Password Reset"
metadata = {"type": "password_reset", "email": email, "code": code}
await email_queue.enqueue_email(
to_email=email,
subject=subject,
content=plain_content,
html_content=html_content,
metadata=metadata
metadata=metadata,
)
logger.info(f"[Password Reset] Enqueued reset code email to {email}")
return True
except Exception as e:
logger.error(f"[Password Reset] Failed to enqueue email: {e}")
return False
async def reset_password(
self,
email: str,
reset_code: str,
new_password: str,
ip_address: str,
redis: Redis
) -> Tuple[bool, str]:
redis: Redis,
) -> tuple[bool, str]:
"""
重置密码
Args:
email: 邮箱地址
reset_code: 重置验证码
new_password: 新密码
ip_address: 请求IP
redis: Redis连接
Returns:
Tuple[success, message]
"""
email = email.lower().strip()
reset_code = reset_code.strip()
async with with_db() as session:
# 从Redis获取验证码数据
reset_code_key = self._get_reset_code_key(email)
reset_data_str = await redis.get(reset_code_key)
if not reset_data_str:
return False, "验证码无效或已过期"
try:
reset_data = json.loads(reset_data_str)
except json.JSONDecodeError:
return False, "验证码数据格式错误"
# 验证验证码
if reset_data.get("reset_code") != reset_code:
return False, "验证码错误"
# 检查是否已使用
if reset_data.get("used", False):
return False, "验证码已使用"
# 验证邮箱匹配
if reset_data.get("email") != email:
return False, "邮箱地址不匹配"
# 查找用户
user_query = select(User).where(User.email == email)
user_result = await session.exec(user_query)
user = user_result.first()
if not user:
return False, "用户不存在"
if user.id is None:
return False, "用户ID无效"
# 验证用户ID匹配
if reset_data.get("user_id") != user.id:
return False, "用户信息不匹配"
# 密码强度检查
if len(new_password) < 6:
return False, "密码长度至少为6位"
try:
# 先标记验证码为已使用(在数据库操作之前)
reset_data["used"] = True
reset_data["used_at"] = datetime.now(UTC).isoformat()
# 保存用户ID用于日志记录
user_id = user.id
# 更新用户密码
password_hash = get_password_hash(new_password)
user.pw_bcrypt = password_hash # 使用正确的字段名称 pw_bcrypt 而不是 password_hash
# 提交数据库更改
await session.commit()
# 使该用户的所有现有令牌失效(使其他客户端登录失效)
tokens_deleted = await invalidate_user_tokens(session, user_id)
# 数据库操作成功后更新Redis状态
await redis.setex(reset_code_key, 300, json.dumps(reset_data)) # 保留5分钟用于日志记录
logger.info(f"[Password Reset] User {user_id} ({email}) successfully reset password from IP {ip_address}, invalidated {tokens_deleted} tokens")
logger.info(
f"[Password Reset] User {user_id} ({email}) successfully reset password from IP {ip_address},"
f" invalidated {tokens_deleted} tokens"
)
return True, "密码重置成功,所有设备已被登出"
except Exception as e:
# 不要在异常处理中访问user.id可能触发数据库操作
user_id = reset_data.get("user_id", "未知")
logger.error(f"[Password Reset] Failed to reset password for user {user_id}: {e}")
await session.rollback()
# 数据库回滚时需要恢复Redis中的验证码状态
try:
# 恢复验证码为未使用状态
@@ -394,35 +386,39 @@ class PasswordResetService:
"created_at": reset_data.get("created_at"),
"ip_address": reset_data.get("ip_address"),
"user_agent": reset_data.get("user_agent"),
"used": False # 恢复为未使用状态
"used": False, # 恢复为未使用状态
}
# 计算剩余的TTL时间
created_at = datetime.fromisoformat(reset_data.get("created_at", ""))
elapsed = (datetime.now(UTC) - created_at).total_seconds()
remaining_ttl = max(0, 600 - int(elapsed)) # 600秒总过期时间
if remaining_ttl > 0:
await redis.setex(reset_code_key, remaining_ttl, json.dumps(original_reset_data))
await redis.setex(
reset_code_key,
remaining_ttl,
json.dumps(original_reset_data),
)
logger.info(f"[Password Reset] Restored Redis state after database rollback for {email}")
else:
# 如果已经过期,直接删除
await redis.delete(reset_code_key)
logger.info(f"[Password Reset] Removed expired reset code after database rollback for {email}")
except Exception as redis_error:
logger.error(f"[Password Reset] Failed to restore Redis state after rollback: {redis_error}")
return False, "密码重置失败,请稍后重试"
async def get_reset_attempts_count(self, email: str, redis: Redis) -> int:
"""
获取邮箱的重置尝试次数(通过检查频率限制键)
Args:
email: 邮箱地址
redis: Redis连接
Returns:
尝试次数
"""

View File

@@ -34,9 +34,7 @@ class DateTimeEncoder(json.JSONEncoder):
def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(
data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")
)
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
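Assuming DateTimeEncoder renders datetime objects as ISO-8601 strings (consistent with its name; its body falls outside this hunk), usage looks like:

from datetime import UTC, datetime

payload = {"user_id": 1, "seen_at": datetime(2025, 8, 22, 8, 21, 52, tzinfo=UTC)}
safe_json_dumps(payload)
# -> '{"user_id":1,"seen_at":"2025-08-22T08:21:52+00:00"}'
#    compact separators, ensure_ascii=False, datetimes handled by the encoder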
class RankingCacheService:
@@ -225,9 +223,7 @@ class RankingCacheService:
) -> None:
"""刷新排行榜缓存"""
if self._refreshing:
logger.debug(
f"Ranking cache refresh already in progress for {ruleset}:{type}"
)
logger.debug(f"Ranking cache refresh already in progress for {ruleset}:{type}")
return
# 使用配置文件的设置
@@ -253,9 +249,7 @@ class RankingCacheService:
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(
col(UserStatistics.user).has(country_code=country.upper())
)
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
@@ -277,11 +271,7 @@ class RankingCacheService:
for page in range(1, max_pages + 1):
try:
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
)
statistics_data = statistics_list.all()
@@ -291,9 +281,7 @@ class RankingCacheService:
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
# 将 UserStatisticsResp 转换为字典,处理所有序列化问题
user_dict = json.loads(user_stats_resp.model_dump_json())
ranking_data.append(user_dict)
@@ -323,9 +311,7 @@ class RankingCacheService:
) -> None:
"""刷新地区排行榜缓存"""
if self._refreshing:
logger.debug(
f"Country ranking cache refresh already in progress for {ruleset}"
)
logger.debug(f"Country ranking cache refresh already in progress for {ruleset}")
return
if max_pages is None:
@@ -449,9 +435,7 @@ class RankingCacheService:
for country in top_countries:
for mode in game_modes:
for ranking_type in ranking_types:
task = self.refresh_ranking_cache(
session, mode, ranking_type, country
)
task = self.refresh_ranking_cache(session, mode, ranking_type, country)
refresh_tasks.append(task)
# 地区排行榜
@@ -493,9 +477,7 @@ class RankingCacheService:
if keys:
await self.redis.delete(*keys)
deleted_keys += len(keys)
logger.info(
f"Invalidated {len(keys)} cache keys for {ruleset}:{type}"
)
logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}")
elif ruleset:
# 删除特定游戏模式的所有缓存
patterns = [
@@ -563,9 +545,7 @@ class RankingCacheService:
"cached_user_rankings": len(ranking_keys),
"cached_country_rankings": len(country_keys),
"total_cached_rankings": len(total_keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
"refreshing": self._refreshing,
}
except Exception as e:

View File

@@ -35,12 +35,8 @@ async def recalculate():
fetcher = await get_fetcher()
redis = get_redis()
for mode in GameMode:
await session.execute(
delete(PPBestScore).where(col(PPBestScore.gamemode) == mode)
)
await session.execute(
delete(BestScore).where(col(BestScore.gamemode) == mode)
)
await session.execute(delete(PPBestScore).where(col(PPBestScore.gamemode) == mode))
await session.execute(delete(BestScore).where(col(BestScore.gamemode) == mode))
await session.commit()
logger.info(f"Recalculating for mode: {mode}")
statistics_list = (
@@ -53,32 +49,21 @@ async def recalculate():
).all()
await asyncio.gather(
*[
_recalculate_pp(
statistics.user_id, statistics.mode, session, fetcher, redis
)
_recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis)
for statistics in statistics_list
]
)
await asyncio.gather(
*[
_recalculate_best_score(
statistics.user_id, statistics.mode, session
)
_recalculate_best_score(statistics.user_id, statistics.mode, session)
for statistics in statistics_list
]
)
await session.commit()
await asyncio.gather(
*[
_recalculate_statistics(statistics, session)
for statistics in statistics_list
]
)
await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list])
await session.commit()
logger.success(
f"Recalculated for mode: {mode}, total users: {len(statistics_list)}"
)
logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}")
async def _recalculate_pp(
@@ -104,9 +89,7 @@ async def _recalculate_pp(
beatmap_id = score.beatmap_id
while time > 0:
try:
db_beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=beatmap_id
)
db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id)
except HTTPError:
time -= 1
await asyncio.sleep(2)
@@ -116,9 +99,7 @@ async def _recalculate_pp(
score.pp = 0
return
try:
pp = await pre_fetch_and_calculate_pp(
score, beatmap_id, session, redis, fetcher
)
pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
score.pp = pp
if pp == 0:
return
@@ -138,15 +119,10 @@ async def _recalculate_pp(
await asyncio.sleep(2)
continue
except Exception:
logger.exception(
f"Error calculating pp for score {score.id} on beatmap {beatmap_id}"
)
logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}")
return
if time <= 0:
logger.warning(
f"Failed to fetch beatmap {beatmap_id} after 10 attempts, "
"retrying later..."
)
logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...")
return score
while len(scores) > 0:
@@ -271,9 +247,7 @@ async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSess
statistics.count_100 += score.n100 + score.nkatu
statistics.count_50 += score.n50
statistics.count_miss += score.nmiss
statistics.total_hits += (
score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50
)
statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50
if ranked and score.passed:
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
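_recalculate_pp above wraps each beatmap fetch in a bounded retry: decrement an attempt counter, sleep two seconds on HTTPError, give up once the budget is spent. Extracted as a standalone sketch (HTTPError is assumed to come from httpx, matching the except clause above):

import asyncio

from httpx import HTTPError  # assumption: the HTTPError caught in _recalculate_pp

async def fetch_with_retries(fetch, attempts: int = 10, delay: float = 2.0):
    """Retry a flaky async fetch a fixed number of times before giving up."""
    while attempts > 0:
        try:
            return await fetch()
        except HTTPError:
            attempts -= 1
            await asyncio.sleep(delay)
    return None  # caller logs and retries the whole score later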

View File

@@ -18,6 +18,7 @@ from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.lazer_user import RANKING_INCLUDES, User, UserResp
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
from app.utils import bg_tasks
class RedisMessageSystem:
@@ -67,12 +68,11 @@ class RedisMessageSystem:
# 获取频道类型以判断是否需要存储到数据库
async with with_db() as session:
from app.database.chat import ChatChannel, ChannelType
from app.database.chat import ChannelType, ChatChannel
from sqlmodel import select
channel_result = await session.exec(
select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)
)
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
channel_type = channel_result.first()
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
@@ -132,17 +132,14 @@ class RedisMessageSystem:
if is_multiplayer:
logger.info(
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database"
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
" will not be persisted to database"
)
else:
logger.info(
f"Message {message_id} sent to Redis cache for channel {channel_id}"
)
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return response
async def get_messages(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[ChatMessageResp]:
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
"""
获取频道消息 - 优先从 Redis 获取最新消息
@@ -166,9 +163,7 @@ class RedisMessageSystem:
# 获取发送者信息
sender = await session.get(User, msg_data["sender_id"])
if sender:
user_resp = await UserResp.from_db(
sender, session, RANKING_INCLUDES
)
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
@@ -223,39 +218,28 @@ class RedisMessageSystem:
async def _generate_message_id(self, channel_id: int) -> int:
"""生成唯一的消息ID - 确保全局唯一且严格递增"""
# 使用全局计数器确保所有频道的消息ID都是严格递增的
message_id = await self._redis_exec(
self.redis.incr, "global_message_id_counter"
)
message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter")
# 同时更新频道的最后消息ID用于客户端状态同步
await self._redis_exec(
self.redis.set, f"channel:{channel_id}:last_msg_id", message_id
)
await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id)
return message_id
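INCR is atomic, so concurrent senders on any channel get distinct, strictly increasing IDs — the property the persistence queue and since-based pagination depend on. A minimal demonstration (synchronous client for brevity):

import redis

r = redis.Redis()
ids = [r.incr("global_message_id_counter") for _ in range(3)]
assert ids == sorted(ids) and len(set(ids)) == 3  # unique and monotonic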
async def _store_to_redis(
self, message_id: int, channel_id: int, message_data: dict[str, Any]
):
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
"""存储消息到 Redis"""
try:
# 检查是否是多人房间消息
is_multiplayer = message_data.get("is_multiplayer", False)
# 存储消息数据
await self._redis_exec(
self.redis.hset,
f"msg:{channel_id}:{message_id}",
mapping={
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)
for k, v in message_data.items()
},
mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()},
)
# 设置消息过期时间7天
await self._redis_exec(
self.redis.expire, f"msg:{channel_id}:{message_id}", 604800
)
await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800)
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
channel_messages_key = f"channel:{channel_id}:messages"
@@ -264,14 +248,10 @@ class RedisMessageSystem:
try:
key_type = await self._redis_exec(self.redis.type, channel_messages_key)
if key_type and key_type != "zset":
logger.warning(
f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}"
)
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self._redis_exec(self.redis.delete, channel_messages_key)
except Exception as type_check_error:
logger.warning(
f"Failed to check key type for {channel_messages_key}: {type_check_error}"
)
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
# 如果检查失败,直接删除键以确保清理
await self._redis_exec(self.redis.delete, channel_messages_key)
@@ -283,15 +263,11 @@ class RedisMessageSystem:
)
# 保持频道消息列表大小最多1000条
await self._redis_exec(
self.redis.zremrangebyrank, channel_messages_key, 0, -1001
)
await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001)
# 只有非多人房间消息才添加到待持久化队列
if not is_multiplayer:
await self._redis_exec(
self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}"
)
await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}")
logger.debug(f"Message {message_id} added to persistence queue")
else:
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
@@ -300,9 +276,7 @@ class RedisMessageSystem:
logger.error(f"Failed to store message to Redis: {e}")
raise
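Reads mirror this layout: the sorted set yields newest-first pages directly, and a since cursor becomes a score range query. A hedged read-back sketch — the zadd itself is elided from this hunk, so the assumption that members are message IDs scored by message ID comes from the surrounding code, not from a line shown here:

import redis

r = redis.Redis()

def latest_ids(channel_id: int, limit: int = 50, since: int = 0) -> list[int]:
    key = f"channel:{channel_id}:messages"
    if since > 0:
        # everything strictly newer than `since`, oldest first
        members = r.zrangebyscore(key, f"({since}", "+inf", start=0, num=limit)
    else:
        members = r.zrevrange(key, 0, limit - 1)  # newest `limit` messages
    return [int(m) for m in members]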
async def _get_from_redis(
self, channel_id: int, limit: int = 50, since: int = 0
) -> list[dict[str, Any]]:
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
"""从 Redis 获取消息"""
try:
# 获取消息键列表按消息ID排序
@@ -340,9 +314,7 @@ class RedisMessageSystem:
# 尝试解析 JSON
try:
if k in ["grade_counts", "level"] or v.startswith(
("{", "[")
):
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
message_data[k] = json.loads(v)
elif k in ["message_id", "channel_id", "sender_id"]:
message_data[k] = int(v)
@@ -368,9 +340,7 @@ class RedisMessageSystem:
logger.error(f"Failed to get messages from Redis: {e}")
return []
async def _backfill_from_database(
self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int
):
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
"""从数据库补充历史消息"""
try:
# 找到最小的消息ID
@@ -404,9 +374,7 @@ class RedisMessageSystem:
except Exception as e:
logger.error(f"Failed to backfill from database: {e}")
async def _get_from_database_only(
self, channel_id: int, limit: int, since: int
) -> list[ChatMessageResp]:
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
"""仅从数据库获取消息(回退方案)"""
try:
async with with_db() as session:
@@ -417,20 +385,14 @@ class RedisMessageSystem:
if since > 0:
# 获取指定ID之后的消息按ID正序
query = query.where(col(ChatMessage.message_id) > since)
query = query.order_by(col(ChatMessage.message_id).asc()).limit(
limit
)
query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit)
else:
# 获取最新消息按ID倒序最新的在前面
query = query.order_by(col(ChatMessage.message_id).desc()).limit(
limit
)
query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit)
messages = (await session.exec(query)).all()
results = [
await ChatMessageResp.from_db(msg, session) for msg in messages
]
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
# 如果是 since > 0保持正序否则反转为时间正序
if since == 0:
@@ -451,9 +413,7 @@ class RedisMessageSystem:
# 获取待处理的消息
message_keys = []
for _ in range(self.max_batch_size):
key = await self._redis_exec(
self.redis.brpop, ["pending_messages"], timeout=1
)
key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1)
if key:
# key 是 (queue_name, value) 的元组
value = key[1]
@@ -483,9 +443,7 @@ class RedisMessageSystem:
channel_id, message_id = map(int, key.split(":"))
# 从 Redis 获取消息数据
raw_data = await self._redis_exec(
self.redis.hgetall, f"msg:{channel_id}:{message_id}"
)
raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}")
if not raw_data:
continue
@@ -546,9 +504,7 @@ class RedisMessageSystem:
# 提交批次
try:
await session.commit()
logger.info(
f"Batch of {len(message_keys)} messages committed to database"
)
logger.info(f"Batch of {len(message_keys)} messages committed to database")
except Exception as e:
logger.error(f"Failed to commit message batch: {e}")
await session.rollback()
@@ -559,7 +515,7 @@ class RedisMessageSystem:
self._running = True
self._batch_timer = asyncio.create_task(self._batch_persist_to_database())
# 启动时初始化消息ID计数器
asyncio.create_task(self._initialize_message_counter())
bg_tasks.add_task(self._initialize_message_counter)
logger.info("Redis message system started")
async def _initialize_message_counter(self):
@@ -576,27 +532,19 @@ class RedisMessageSystem:
max_id = result.one() or 0
# 检查 Redis 中的计数器值
current_counter = await self._redis_exec(
self.redis.get, "global_message_id_counter"
)
current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter")
current_counter = int(current_counter) if current_counter else 0
# 设置计数器为两者中的最大值
initial_counter = max(max_id, current_counter)
await self._redis_exec(
self.redis.set, "global_message_id_counter", initial_counter
)
await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter)
logger.info(
f"Initialized global message ID counter to {initial_counter}"
)
logger.info(f"Initialized global message ID counter to {initial_counter}")
except Exception as e:
logger.error(f"Failed to initialize message counter: {e}")
# 如果初始化失败,设置一个安全的起始值
await self._redis_exec(
self.redis.setnx, "global_message_id_counter", 1000000
)
await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000)
async def _cleanup_redis_keys(self):
"""清理可能存在问题的 Redis 键"""
@@ -612,9 +560,7 @@ class RedisMessageSystem:
try:
key_type = await self._redis_exec(self.redis.type, key)
if key_type and key_type != "zset":
logger.warning(
f"Cleaning up Redis key {key} with wrong type: {key_type}"
)
logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}")
await self._redis_exec(self.redis.delete, key)
except Exception as cleanup_error:
logger.warning(f"Failed to cleanup key {key}: {cleanup_error}")

View File

@@ -14,15 +14,11 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
async def create_playlist_room_from_api(
session: AsyncSession, room: APIUploadedRoom, host_id: int
) -> Room:
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
db_room = room.to_room()
db_room.host_id = host_id
db_room.starts_at = datetime.now(UTC)
db_room.ends_at = db_room.starts_at + timedelta(
minutes=db_room.duration if db_room.duration is not None else 0
)
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
session.add(db_room)
await session.commit()
await session.refresh(db_room)
@@ -87,13 +83,9 @@ async def create_playlist_room(
return db_room
async def add_playlists_to_room(
session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int
):
async def add_playlists_to_room(session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int):
for item in playlist:
if not (
await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))
).first():
if not (await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))).first():
fetcher = await get_fetcher()
await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id)
item.id = await Playlist.get_next_id_for_room(room_id, session)

View File

@@ -4,15 +4,15 @@ API 状态管理 - 模拟 osu! 的 APIState 和会话管理
from __future__ import annotations
from enum import Enum
from typing import Optional
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
class APIState(str, Enum):
"""API 连接状态,对应 osu! 的 APIState"""
OFFLINE = "offline"
CONNECTING = "connecting"
REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证
@@ -22,6 +22,7 @@ class APIState(str, Enum):
class UserSession(BaseModel):
"""用户会话信息"""
user_id: int
username: str
email: str
@@ -38,10 +39,10 @@ class UserSession(BaseModel):
class SessionManager:
"""会话管理器"""
def __init__(self):
self._sessions: dict[str, UserSession] = {}
def create_session(
self,
user_id: int,
@@ -49,19 +50,19 @@ class SessionManager:
email: str,
ip_address: str,
country_code: str | None = None,
is_new_location: bool = False
is_new_location: bool = False,
) -> UserSession:
"""创建新的用户会话"""
import secrets
session_token = secrets.token_urlsafe(32)
# 根据是否为新位置决定初始状态
if is_new_location:
state = APIState.REQUIRES_SECOND_FACTOR_AUTH
else:
state = APIState.ONLINE
session = UserSession(
user_id=user_id,
username=username,
@@ -71,33 +72,33 @@ class SessionManager:
requires_verification=is_new_location,
ip_address=ip_address,
country_code=country_code,
is_new_location=is_new_location
is_new_location=is_new_location,
)
self._sessions[session_token] = session
return session
def get_session(self, session_token: str) -> UserSession | None:
"""获取会话"""
return self._sessions.get(session_token)
def update_session_state(self, session_token: str, state: APIState):
"""更新会话状态"""
if session_token in self._sessions:
self._sessions[session_token].state = state
def mark_verification_sent(self, session_token: str):
"""标记验证邮件已发送"""
if session_token in self._sessions:
session = self._sessions[session_token]
session.verification_sent = True
session.last_verification_attempt = datetime.now()
def increment_failed_attempts(self, session_token: str):
"""增加失败尝试次数"""
if session_token in self._sessions:
self._sessions[session_token].failed_attempts += 1
def verify_session(self, session_token: str) -> bool:
"""验证会话成功"""
if session_token in self._sessions:
@@ -106,11 +107,11 @@ class SessionManager:
session.requires_verification = False
return True
return False
def remove_session(self, session_token: str):
"""移除会话"""
self._sessions.pop(session_token, None)
def cleanup_expired_sessions(self):
"""清理过期会话"""
# 这里可以实现清理逻辑
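The cleanup hook is left as a stub. One plausible implementation, assuming sessions gain a created_at timestamp (not among the fields shown above):

from datetime import datetime, timedelta

def cleanup_expired_sessions(self, max_age: timedelta = timedelta(hours=24)) -> int:
    """Drop sessions older than `max_age`; returns the number removed."""
    cutoff = datetime.now() - max_age
    expired = [
        token
        for token, s in self._sessions.items()
        if getattr(s, "created_at", cutoff) < cutoff  # created_at is an assumed field
    ]
    for token in expired:
        self._sessions.pop(token, None)
    return len(expired)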

View File

@@ -26,14 +26,12 @@ async def cleanup_stale_online_users() -> tuple[int, int]:
# 检查在线用户的最后活动时间
current_time = datetime.utcnow()
stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期
stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期 # noqa: F841
# 对于在线用户我们检查metadata在线标记
stale_online_users = []
for user_id in online_users:
user_id_str = (
user_id.decode() if isinstance(user_id, bytes) else str(user_id)
)
user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
metadata_key = f"metadata:online:{user_id_str}"
# 如果metadata标记不存在说明用户已经离线
@@ -42,9 +40,7 @@ async def cleanup_stale_online_users() -> tuple[int, int]:
# 清理过期的在线用户
if stale_online_users:
await _redis_exec(
redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users
)
await _redis_exec(redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users)
online_cleaned = len(stale_online_users)
logger.info(f"Cleaned {online_cleaned} stale online users")
@@ -52,22 +48,19 @@ async def cleanup_stale_online_users() -> tuple[int, int]:
# 只有当用户明确不在任何hub连接中时才移除
stale_playing_users = []
for user_id in playing_users:
user_id_str = (
user_id.decode() if isinstance(user_id, bytes) else str(user_id)
)
user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id)
metadata_key = f"metadata:online:{user_id_str}"
# 只有当metadata在线标记完全不存在且用户也不在在线列表中时
# 才认为用户真正离线
if (not await redis_async.exists(metadata_key) and
user_id_str not in [u.decode() if isinstance(u, bytes) else str(u) for u in online_users]):
if not await redis_async.exists(metadata_key) and user_id_str not in [
u.decode() if isinstance(u, bytes) else str(u) for u in online_users
]:
stale_playing_users.append(user_id_str)
# 清理过期的游玩用户
if stale_playing_users:
await _redis_exec(
redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users
)
await _redis_exec(redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users)
playing_cleaned = len(stale_playing_users)
logger.info(f"Cleaned {playing_cleaned} stale playing users")

View File

@@ -61,26 +61,29 @@ class StatsScheduler:
try:
# 计算下次区间结束时间
now = datetime.utcnow()
# 计算当前区间的结束时间
current_minute = (now.minute // 30) * 30
current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30)
current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(
minutes=30
)
# 如果当前时间已经超过了当前区间结束时间,说明需要等待下一个区间结束
if now >= current_interval_end:
current_interval_end += timedelta(minutes=30)
# 计算需要等待的时间
sleep_seconds = (current_interval_end - now).total_seconds()
# 添加小的缓冲时间,确保区间真正结束后再处理
sleep_seconds += 10 # 额外等待10秒
# 限制等待时间范围
sleep_seconds = max(min(sleep_seconds, 32 * 60), 10)
logger.debug(
f"Next interval finalization in {sleep_seconds / 60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}"
f"Next interval finalization in {sleep_seconds / 60:.1f} "
f"minutes at {current_interval_end.strftime('%H:%M:%S')}"
)
await asyncio.sleep(sleep_seconds)
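The interval arithmetic is easiest to verify with concrete times — truncate to the current half-hour, add thirty minutes, and (plus the 10-second buffer) that is when finalization fires:

from datetime import datetime, timedelta

def interval_end(now: datetime) -> datetime:
    current_minute = (now.minute // 30) * 30
    end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30)
    return end + timedelta(minutes=30) if now >= end else end

assert interval_end(datetime(2025, 8, 22, 14, 17)) == datetime(2025, 8, 22, 14, 30)
assert interval_end(datetime(2025, 8, 22, 14, 47)) == datetime(2025, 8, 22, 15, 0)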
@@ -137,7 +140,8 @@ class StatsScheduler:
online_cleaned, playing_cleaned = await cleanup_stale_online_users()
if online_cleaned > 0 or playing_cleaned > 0:
logger.info(
f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users"
f"Initial cleanup: removed {online_cleaned} stale online users,"
f" {playing_cleaned} stale playing users"
)
await refresh_redis_key_expiry()

View File

@@ -31,9 +31,7 @@ class RedisSubscriber:
async def listen(self):
while True:
message = await self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=None
)
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=None)
if message is not None and message["type"] == "message":
matched_handlers: list[Callable[[str, str], Awaitable[Any]]] = []
@@ -53,10 +51,7 @@ class RedisSubscriber:
if matched_handlers:
await asyncio.gather(
*[
handler(message["channel"], message["data"])
for handler in matched_handlers
]
*[handler(message["channel"], message["data"]) for handler in matched_handlers]
)
def start(self):
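For context, this dispatch loop is the consuming side of plain Redis pub/sub; the producing side (as in _notify_message_update earlier) is a PUBLISH with a JSON payload. A minimal end-to-end sketch with redis.asyncio (the channel name is illustrative):

import asyncio
import json

import redis.asyncio as redis

async def main() -> None:
    r = redis.Redis()
    pubsub = r.pubsub()
    await pubsub.subscribe("chat:updates")
    await r.publish("chat:updates", json.dumps({"temp_id": -1, "real_id": 42}))
    msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
    if msg and msg["type"] == "message":
        print(json.loads(msg["data"]))  # -> {'temp_id': -1, 'real_id': 42}

asyncio.run(main())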

View File

@@ -46,12 +46,7 @@ class ScoreSubscriber(RedisSubscriber):
return
async with with_db() as session:
score = await session.get(Score, score_id)
if (
not score
or not score.passed
or score.room_id is None
or score.playlist_item_id is None
):
if not score or not score.passed or score.room_id is None or score.playlist_item_id is None:
return
if not self.room_subscriber.get(score.room_id, []):
return

View File

@@ -47,17 +47,13 @@ class UserCacheService:
self._refreshing = False
self._background_tasks: set = set()
def _get_v1_user_cache_key(
self, user_id: int, ruleset: GameMode | None = None
) -> str:
def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str:
"""生成 V1 用户缓存键"""
if ruleset:
return f"v1_user:{user_id}:ruleset:{ruleset}"
return f"v1_user:{user_id}"
async def get_v1_user_from_cache(
self, user_id: int, ruleset: GameMode | None = None
) -> dict | None:
async def get_v1_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> dict | None:
"""从缓存获取 V1 用户信息"""
try:
cache_key = self._get_v1_user_cache_key(user_id, ruleset)
@@ -96,9 +92,7 @@ class UserCacheService:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
logger.info(
f"Invalidated {len(keys)} V1 cache entries for user {user_id}"
)
logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}")
except Exception as e:
logger.error(f"Error invalidating V1 user cache: {e}")
@@ -126,9 +120,7 @@ class UserCacheService:
"""生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache(
self, user_id: int, ruleset: GameMode | None = None
) -> UserResp | None:
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
"""从缓存获取用户信息"""
try:
cache_key = self._get_user_cache_key(user_id, ruleset)
@@ -172,14 +164,10 @@ class UserCacheService:
) -> list[ScoreResp] | None:
"""从缓存获取用户成绩"""
try:
cache_key = self._get_user_scores_cache_key(
user_id, score_type, mode, limit, offset
)
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(
f"User scores cache hit for user {user_id}, type {score_type}"
)
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
data = json.loads(cached_data)
return [ScoreResp(**score_data) for score_data in data]
return None
@@ -201,16 +189,12 @@ class UserCacheService:
try:
if expire_seconds is None:
expire_seconds = settings.user_scores_cache_expire_seconds
cache_key = self._get_user_scores_cache_key(
user_id, score_type, mode, limit, offset
)
cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset)
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
scores_json_list = [score.model_dump_json() for score in scores]
cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(
f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s"
)
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user scores: {e}")
@@ -219,14 +203,10 @@ class UserCacheService:
) -> list[Any] | None:
"""从缓存获取用户谱面集"""
try:
cache_key = self._get_user_beatmapsets_cache_key(
user_id, beatmapset_type, limit, offset
)
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(
f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}"
)
logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}")
return json.loads(cached_data)
return None
except Exception as e:
@@ -246,9 +226,7 @@ class UserCacheService:
try:
if expire_seconds is None:
expire_seconds = settings.user_beatmapsets_cache_expire_seconds
cache_key = self._get_user_beatmapsets_cache_key(
user_id, beatmapset_type, limit, offset
)
cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset)
# 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps
serialized_beatmapsets = []
for bms in beatmapsets:
@@ -258,9 +236,7 @@ class UserCacheService:
serialized_beatmapsets.append(safe_json_dumps(bms))
cached_data = f"[{','.join(serialized_beatmapsets)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(
f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s"
)
logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user beatmapsets: {e}")
@@ -276,9 +252,7 @@ class UserCacheService:
except Exception as e:
logger.error(f"Error invalidating user cache: {e}")
async def invalidate_user_scores_cache(
self, user_id: int, mode: GameMode | None = None
):
async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None):
"""使用户成绩缓存失效"""
try:
# 删除用户成绩相关缓存
@@ -287,9 +261,7 @@ class UserCacheService:
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)
logger.info(
f"Invalidated {len(keys)} score cache entries for user {user_id}"
)
logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}")
except Exception as e:
logger.error(f"Error invalidating user scores cache: {e}")
@@ -303,9 +275,7 @@ class UserCacheService:
logger.info(f"Preloading cache for {len(user_ids)} users")
# 批量获取用户
users = (
await session.exec(select(User).where(col(User.id).in_(user_ids)))
).all()
users = (await session.exec(select(User).where(col(User.id).in_(user_ids)))).all()
# 串行缓存用户信息,避免并发数据库访问问题
cached_count = 0
@@ -332,9 +302,7 @@ class UserCacheService:
except Exception as e:
logger.error(f"Error caching single user {user.id}: {e}")
async def refresh_user_cache_on_score_submit(
self, session: AsyncSession, user_id: int, mode: GameMode
):
async def refresh_user_cache_on_score_submit(self, session: AsyncSession, user_id: int, mode: GameMode):
"""成绩提交后刷新用户缓存"""
try:
# 使相关缓存失效(包括 v1 和 v2
@@ -367,24 +335,12 @@ class UserCacheService:
continue
return {
"cached_users": len(
[
k
for k in user_keys
if ":scores:" not in k and ":beatmapsets:" not in k
]
),
"cached_v1_users": len(
[k for k in v1_user_keys if ":scores:" not in k]
),
"cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]),
"cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]),
"cached_user_scores": len([k for k in user_keys if ":scores:" in k]),
"cached_user_beatmapsets": len(
[k for k in user_keys if ":beatmapsets:" in k]
),
"cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]),
"total_cached_entries": len(all_keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
"refreshing": self._refreshing,
}
except Exception as e: