diff --git a/app/calculator.py b/app/calculator.py index 4fde08e..da2d258 100644 --- a/app/calculator.py +++ b/app/calculator.py @@ -86,26 +86,34 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f f"Error checking if beatmap {score.beatmap_id} is suspicious" ) - map = rosu.Beatmap(content=beatmap) - mods = deepcopy(score.mods.copy()) - parse_enum_to_str(int(score.gamemode), mods) - map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType] - perf = rosu.Performance( - mods=mods, - lazer=True, - accuracy=clamp(score.accuracy * 100, 0, 100), - combo=score.max_combo, - large_tick_hits=score.nlarge_tick_hit or 0, - slider_end_hits=score.nslider_tail_hit or 0, - small_tick_hits=score.nsmall_tick_hit or 0, - n_geki=score.ngeki, - n_katu=score.nkatu, - n300=score.n300, - n100=score.n100, - n50=score.n50, - misses=score.nmiss, - ) - attrs = perf.calculate(map) + # 使用线程池执行计算密集型操作以避免阻塞事件循环 + import asyncio + loop = asyncio.get_event_loop() + + def _calculate_pp_sync(): + map = rosu.Beatmap(content=beatmap) + mods = deepcopy(score.mods.copy()) + parse_enum_to_str(int(score.gamemode), mods) + map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType] + perf = rosu.Performance( + mods=mods, + lazer=True, + accuracy=clamp(score.accuracy * 100, 0, 100), + combo=score.max_combo, + large_tick_hits=score.nlarge_tick_hit or 0, + slider_end_hits=score.nslider_tail_hit or 0, + small_tick_hits=score.nsmall_tick_hit or 0, + n_geki=score.ngeki, + n_katu=score.nkatu, + n300=score.n300, + n100=score.n100, + n50=score.n50, + misses=score.nmiss, + ) + return perf.calculate(map) + + # 在线程池中执行计算 + attrs = await loop.run_in_executor(None, _calculate_pp_sync) pp = attrs.pp # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp @@ -122,6 +130,132 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f return pp +async def pre_fetch_and_calculate_pp( + score: "Score", + beatmap_id: int, + session: AsyncSession, + redis, + fetcher +) -> float: + """ + 优化版PP计算:预先获取beatmap文件并使用缓存 + """ + import asyncio + + from app.database.beatmap import BannedBeatmaps + + # 快速检查是否被封禁 + if settings.suspicious_score_check: + beatmap_banned = ( + await session.exec( + select(exists()).where( + col(BannedBeatmaps.beatmap_id) == beatmap_id + ) + ) + ).first() + if beatmap_banned: + return 0 + + # 异步获取beatmap原始文件,利用已有的Redis缓存机制 + try: + beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + except Exception as e: + logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}") + return 0 + + # 在获取文件的同时,可以检查可疑beatmap + if settings.suspicious_score_check: + try: + # 将可疑检查也移到线程池中执行 + def _check_suspicious(): + return is_suspicious_beatmap(beatmap_raw) + + loop = asyncio.get_event_loop() + is_sus = await loop.run_in_executor(None, _check_suspicious) + if is_sus: + session.add(BannedBeatmaps(beatmap_id=beatmap_id)) + logger.warning(f"Beatmap {beatmap_id} is suspicious, banned") + return 0 + except Exception: + logger.exception(f"Error checking if beatmap {beatmap_id} is suspicious") + + # 调用已优化的PP计算函数 + return await calculate_pp(score, beatmap_raw, session) + + +async def batch_calculate_pp( + scores_data: list[tuple["Score", int]], + session: AsyncSession, + redis, + fetcher +) -> list[float]: + """ + 批量计算PP:适用于重新计算或批量处理场景 + Args: + scores_data: [(score, beatmap_id), ...] 的列表 + Returns: + 对应的PP值列表 + """ + import asyncio + + from app.database.beatmap import BannedBeatmaps + + if not scores_data: + return [] + + # 提取所有唯一的beatmap_id + unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data}) + + # 批量检查被封禁的beatmap + banned_beatmaps = set() + if settings.suspicious_score_check: + banned_results = await session.exec( + select(BannedBeatmaps.beatmap_id).where( + col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids) + ) + ) + banned_beatmaps = set(banned_results.all()) + + # 并发获取所有需要的beatmap原始文件 + async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]: + if beatmap_id in banned_beatmaps: + return beatmap_id, None + try: + content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + return beatmap_id, content + except Exception as e: + logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}") + return beatmap_id, None + + # 并发获取所有beatmap文件 + fetch_tasks = [fetch_beatmap_safe(bid) for bid in unique_beatmap_ids] + fetch_results = await asyncio.gather(*fetch_tasks, return_exceptions=True) + + # 构建beatmap_id -> content的映射 + beatmap_contents = {} + for result in fetch_results: + if isinstance(result, tuple): + beatmap_id, content = result + beatmap_contents[beatmap_id] = content + + # 为每个score计算PP + pp_results = [] + for score, beatmap_id in scores_data: + beatmap_content = beatmap_contents.get(beatmap_id) + if beatmap_content is None: + pp_results.append(0.0) + continue + + try: + pp = await calculate_pp(score, beatmap_content, session) + pp_results.append(pp) + except Exception as e: + logger.error(f"Failed to calculate PP for score {score.id}: {e}") + pp_results.append(0.0) + + return pp_results + + # https://osu.ppy.sh/wiki/Gameplay/Score/Total_score def calculate_level_to_score(n: int) -> float: if n <= 100: diff --git a/app/config.py b/app/config.py index 562722a..c12f405 100644 --- a/app/config.py +++ b/app/config.py @@ -137,6 +137,13 @@ class Settings(BaseSettings): enable_supporter_for_all_users: bool = False enable_all_beatmap_leaderboard: bool = False enable_all_beatmap_pp: bool = False + # 性能优化设置 + enable_beatmap_preload: bool = True + beatmap_cache_expire_hours: int = 24 + max_concurrent_pp_calculations: int = 10 + enable_pp_calculation_threading: bool = True + + # 反作弊设置 suspicious_score_check: bool = True seasonal_backgrounds: Annotated[list[str], BeforeValidator(_parse_list)] = [] banned_name: list[str] = [ diff --git a/app/database/score.py b/app/database/score.py index 40d8235..a445dfe 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -5,7 +5,6 @@ import math from typing import TYPE_CHECKING, Any from app.calculator import ( - calculate_pp, calculate_pp_weight, calculate_score_to_level, calculate_weighted_acc, @@ -772,8 +771,10 @@ async def process_score( maximum_statistics=info.maximum_statistics, ) if can_get_pp: - beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) - pp = await calculate_pp(score, beatmap_raw, session) + from app.calculator import pre_fetch_and_calculate_pp + pp = await pre_fetch_and_calculate_pp( + score, beatmap_id, session, redis, fetcher + ) score.pp = pp session.add(score) user_id = user.id @@ -799,5 +800,5 @@ async def process_score( await session.refresh(score) await session.refresh(score_token) await session.refresh(user) - await redis.publish("score:processed", score.id) + await redis.publish("score:processed", str(score.id or 0)) return score diff --git a/app/fetcher/beatmap_raw.py b/app/fetcher/beatmap_raw.py index f650df3..985fc48 100644 --- a/app/fetcher/beatmap_raw.py +++ b/app/fetcher/beatmap_raw.py @@ -37,8 +37,20 @@ class BeatmapRawFetcher(BaseFetcher): async def get_or_fetch_beatmap_raw( self, redis: redis.Redis, beatmap_id: int ) -> str: - if await redis.exists(f"beatmap:{beatmap_id}:raw"): - return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + from app.config import settings + + cache_key = f"beatmap:{beatmap_id}:raw" + cache_expire = settings.beatmap_cache_expire_hours * 60 * 60 + + # 检查缓存 + if await redis.exists(cache_key): + content = await redis.get(cache_key) + if content: + # 延长缓存时间 + await redis.expire(cache_key, cache_expire) + return content # pyright: ignore[reportReturnType] + + # 获取并缓存 raw = await self.get_beatmap_raw(beatmap_id) - await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + await redis.set(cache_key, raw, ex=cache_expire) return raw diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index 0853e6d..f87f2ed 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -32,8 +32,6 @@ class BeatmapsetFetcher(BaseFetcher): q="", s="leaderboard", sort=sort, # type: ignore - # 不设置 nsfw 和 m,让它们使用默认值 - # 这样 exclude_defaults=True 时它们会被排除 ) homepage_queries.append((query, {})) diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 9033ff5..6f9bb5e 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -37,6 +37,7 @@ from app.dependencies.fetcher import get_fetcher from app.dependencies.storage import get_storage_service from app.dependencies.user import get_client_user, get_current_user from app.fetcher import Fetcher +from app.log import logger from app.models.room import RoomCategory from app.models.score import ( GameMode, @@ -95,6 +96,14 @@ async def submit_score( if not score: raise HTTPException(status_code=404, detail="Score not found") else: + # 智能预取beatmap缓存(异步进行,不阻塞主流程) + try: + from app.service.beatmap_cache_service import get_beatmap_cache_service + cache_service = get_beatmap_cache_service(redis, fetcher) + await cache_service.smart_preload_for_score(beatmap) + except Exception as e: + logger.debug(f"Beatmap preload failed for {beatmap}: {e}") + try: db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap) except HTTPError: diff --git a/app/service/beatmap_cache_service.py b/app/service/beatmap_cache_service.py new file mode 100644 index 0000000..a4f6c17 --- /dev/null +++ b/app/service/beatmap_cache_service.py @@ -0,0 +1,174 @@ +""" +Beatmap缓存预取服务 +用于提前缓存热门beatmap,减少成绩计算时的获取延迟 +""" +from __future__ import annotations + +import asyncio +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING + +from app.config import settings +from app.log import logger + +from redis.asyncio import Redis +from sqlmodel import col, func, select +from sqlmodel.ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + from app.fetcher import Fetcher + + +class BeatmapCacheService: + def __init__(self, redis: Redis, fetcher: "Fetcher"): + self.redis = redis + self.fetcher = fetcher + self._preloading = False + self._background_tasks: set = set() + + async def preload_popular_beatmaps(self, session: AsyncSession, limit: int = 100): + """ + 预加载热门beatmap到Redis缓存 + """ + if self._preloading: + logger.info("Beatmap preloading already in progress") + return + + self._preloading = True + try: + logger.info(f"Starting preload of top {limit} popular beatmaps") + + # 获取过去24小时内最热门的beatmap + recent_time = datetime.now(UTC) - timedelta(hours=24) + + from app.database.score import Score + + popular_beatmaps = ( + await session.exec( + select(Score.beatmap_id, func.count().label("play_count")) + .where(col(Score.ended_at) >= recent_time) + .group_by(col(Score.beatmap_id)) + .order_by(col("play_count").desc()) + .limit(limit) + ) + ).all() + + # 并发预取这些beatmap + preload_tasks = [] + for beatmap_id, _ in popular_beatmaps: + task = self._preload_single_beatmap(beatmap_id) + preload_tasks.append(task) + + if preload_tasks: + results = await asyncio.gather(*preload_tasks, return_exceptions=True) + success_count = sum(1 for r in results if r is True) + logger.info( + f"Preloaded {success_count}/{len(preload_tasks)} " + f"beatmaps successfully" + ) + + except Exception as e: + logger.error(f"Error during beatmap preloading: {e}") + finally: + self._preloading = False + + async def _preload_single_beatmap(self, beatmap_id: int) -> bool: + """ + 预加载单个beatmap + """ + try: + cache_key = f"beatmap:{beatmap_id}:raw" + if await self.redis.exists(cache_key): + # 已经在缓存中,延长过期时间 + await self.redis.expire(cache_key, 60 * 60 * 24) + return True + + # 获取并缓存beatmap + content = await self.fetcher.get_beatmap_raw(beatmap_id) + await self.redis.set(cache_key, content, ex=60 * 60 * 24) + return True + + except Exception as e: + logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}") + return False + + async def smart_preload_for_score(self, beatmap_id: int): + """ + 智能预加载:为即将提交的成绩预加载beatmap + """ + task = asyncio.create_task(self._preload_single_beatmap(beatmap_id)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + async def get_cache_stats(self) -> dict: + """ + 获取缓存统计信息 + """ + try: + keys = await self.redis.keys("beatmap:*:raw") + total_size = 0 + + for key in keys[:100]: # 限制检查数量以避免性能问题 + try: + size = await self.redis.memory_usage(key) + if size: + total_size += size + except Exception: + continue + + return { + "cached_beatmaps": len(keys), + "estimated_total_size_mb": ( + round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 + ), + "preloading": self._preloading, + } + except Exception as e: + logger.error(f"Error getting cache stats: {e}") + return {"error": str(e)} + + async def cleanup_old_cache(self, max_age_hours: int = 48): + """ + 清理过期的缓存 + """ + try: + logger.info(f"Cleaning up beatmap cache older than {max_age_hours} hours") + # Redis会自动清理过期的key,这里主要是记录日志 + keys = await self.redis.keys("beatmap:*:raw") + logger.info(f"Current cache contains {len(keys)} beatmaps") + except Exception as e: + logger.error(f"Error during cache cleanup: {e}") + + +# 全局缓存服务实例 +_cache_service: BeatmapCacheService | None = None + + +def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheService: + """ + 获取beatmap缓存服务实例 + """ + global _cache_service + if _cache_service is None: + _cache_service = BeatmapCacheService(redis, fetcher) + return _cache_service + + +async def schedule_preload_task( + session: AsyncSession, + redis: Redis, + fetcher: "Fetcher" +): + """ + 定时预加载任务 + """ + # 默认启用预加载,除非明确禁用 + enable_preload = getattr(settings, "enable_beatmap_preload", True) + if not enable_preload: + return + + cache_service = get_beatmap_cache_service(redis, fetcher) + try: + await cache_service.preload_popular_beatmaps(session, limit=200) + except Exception as e: + logger.error(f"Scheduled preload task failed: {e}") diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index ed1b48b..5bee18e 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -222,6 +222,10 @@ class SpectatorHub(Hub[StoreClientState]): ) ) logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}") + + # 预缓存beatmap文件以加速后续PP计算 + await self._preload_beatmap_for_pp_calculation(state.beatmap_id) + await self.broadcast_group_call( self.group_id(user_id), "UserBeganPlaying", @@ -446,3 +450,50 @@ class SpectatorHub(Hub[StoreClientState]): if (target_client := self.get_client_by_id(str(target_id))) is not None: await self.call_noblock(target_client, "UserEndedWatching", user_id) logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}") + + async def _preload_beatmap_for_pp_calculation(self, beatmap_id: int) -> None: + """ + 预缓存beatmap文件以加速PP计算 + 当玩家开始游玩时异步预加载beatmap原始文件到Redis缓存 + """ + # 检查是否启用了beatmap预加载功能 + if not settings.enable_beatmap_preload: + return + + try: + # 异步获取fetcher和redis连接 + from app.dependencies.database import get_redis + from app.dependencies.fetcher import get_fetcher + + fetcher = get_fetcher() + redis = get_redis() + + # 检查是否已经缓存,避免重复下载 + cache_key = f"beatmap:raw:{beatmap_id}" + if await redis.exists(cache_key): + logger.debug(f"Beatmap {beatmap_id} already cached, skipping preload") + return + + # 在后台异步预缓存beatmap文件,存储任务引用防止被回收 + task = asyncio.create_task( + self._fetch_beatmap_background(fetcher, redis, beatmap_id) + ) + # 任务完成后自动清理,避免内存泄漏 + task.add_done_callback(lambda t: None) + + except Exception as e: + # 预缓存失败不应该影响正常游戏流程 + logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}") + + async def _fetch_beatmap_background(self, fetcher, redis, beatmap_id: int) -> None: + """ + 后台获取beatmap文件 + """ + try: + # 使用fetcher的get_or_fetch_beatmap_raw方法预缓存 + await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + logger.debug( + f"Successfully preloaded beatmap {beatmap_id} for PP calculation" + ) + except Exception as e: + logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")