diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 62e4e17..71a4df3 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -257,3 +257,11 @@ async def calculate_beatmap_attributes( attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset) await redis.set(key, attr.model_dump_json()) return attr + + +async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []): + if beatmaps: + keys = [f"beatmap:{bid}:raw" for bid in beatmaps] + await redis.delete(*keys) + return + await redis.delete("beatmap:*:raw") diff --git a/tools/recalculate.py b/tools/recalculate.py index 8c9a4cd..2fe42c5 100644 --- a/tools/recalculate.py +++ b/tools/recalculate.py @@ -19,7 +19,7 @@ from app.calculators.performance import CalculateError from app.config import settings from app.const import BANCHOBOT_ID from app.database import TotalScoreBestScore, UserStatistics -from app.database.beatmap import Beatmap, calculate_beatmap_attributes +from app.database.beatmap import Beatmap, calculate_beatmap_attributes, clear_cached_beatmap_raws from app.database.best_scores import BestScore from app.database.score import Score, calculate_playtime, calculate_user_pp from app.dependencies.database import engine, get_redis @@ -40,11 +40,72 @@ logger = log("Recalculate") warnings.filterwarnings("ignore") +class BeatmapCacheManager: + """管理beatmap缓存,确保不超过指定数量""" + + def __init__(self, max_count: int, additional_count: int, redis: Redis): + self.max_count = max_count + self.additional_count = additional_count + self.redis = redis + self.beatmap_ids: list[int] = [] # 记录处理的beatmap id(按顺序) + self.beatmap_id_set: set[int] = set() # 用于快速查找(唯一性) + self.lock = asyncio.Lock() + + async def add_beatmap(self, beatmap_id: int) -> None: + """添加beatmap到缓存跟踪列表""" + if self.max_count <= 0: # 不限制 + return + + async with self.lock: + # 如果已经存在,不重复添加 + if beatmap_id in self.beatmap_id_set: + return + + self.beatmap_ids.append(beatmap_id) + self.beatmap_id_set.add(beatmap_id) + + # 检查是否需要清理 + threshold = self.max_count + max(0, self.additional_count) + if len(self.beatmap_ids) > threshold: + # 计算需要删除的数量 + to_remove_count = max(1, self.additional_count) + await self._cleanup(to_remove_count) + + async def _cleanup(self, count: int) -> None: + """清理最早的count个beatmap缓存""" + if count <= 0 or not self.beatmap_ids: + return + + # 获取要删除的beatmap ids + to_remove = self.beatmap_ids[:count] + self.beatmap_ids = self.beatmap_ids[count:] + + # 从set中移除 + for bid in to_remove: + self.beatmap_id_set.discard(bid) + + # 从Redis中删除缓存 + await clear_cached_beatmap_raws(self.redis, to_remove) + logger.info(f"Cleaned up {len(to_remove)} beatmap caches (total: {len(self.beatmap_ids)})") + + def get_stats(self) -> dict: + """获取统计信息""" + threshold = self.max_count + max(0, self.additional_count) if self.max_count > 0 else "unlimited" + return { + "total_beatmaps": len(self.beatmap_ids), + "max_count": self.max_count, + "additional_count": self.additional_count, + "threshold": threshold, + } + + @dataclass(frozen=True) class GlobalConfig: dry_run: bool concurrency: int output_csv: str | None + max_cached_beatmaps_count: int + additional_count: int @dataclass(frozen=True) @@ -93,6 +154,20 @@ def parse_cli_args( type=str, help="Output results to a CSV file at the specified path", ) + parser.add_argument( + "--max-cached-beatmaps-count", + dest="max_cached_beatmaps_count", + type=int, + default=1500, + help="Maximum number of beatmaps to cache (<=0 means no limit)", + ) + parser.add_argument( + "--additional-count", + dest="additional_count", + type=int, + default=100, + help="Number of additional beatmaps before cleanup (<=0 means cleanup immediately)", + ) subparsers = parser.add_subparsers(dest="command", help="Available commands") @@ -194,6 +269,8 @@ def parse_cli_args( dry_run=args.dry_run, concurrency=max(1, args.concurrency), output_csv=args.output_csv, + max_cached_beatmaps_count=args.max_cached_beatmaps_count, + additional_count=args.additional_count, ) if args.command == "all": @@ -534,6 +611,7 @@ async def recalc_score_pp( fetcher: Fetcher, redis: Redis, score: Score, + cache_manager: BeatmapCacheManager | None = None, ) -> float | None: attempts = 10 while attempts > 0: @@ -558,6 +636,9 @@ async def recalc_score_pp( try: beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, score.beatmap_id) + # 记录使用的beatmap + if cache_manager: + await cache_manager.add_beatmap(score.beatmap_id) new_pp = await calculate_pp(score, beatmap_raw, session) score.pp = new_pp return new_pp @@ -728,6 +809,7 @@ async def recalculate_user_mode_performance( fetcher: Fetcher, redis: Redis, semaphore: asyncio.Semaphore, + cache_manager: BeatmapCacheManager | None = None, csv_writer: CSVWriter | None = None, ) -> None: """Recalculate performance points and best scores (without TotalScoreBestScore).""" @@ -767,7 +849,7 @@ async def recalculate_user_mode_performance( for score in passed_scores: if target_set and score.id not in target_set: continue - result_pp = await recalc_score_pp(session, fetcher, redis, score) + result_pp = await recalc_score_pp(session, fetcher, redis, score, cache_manager) if result_pp is None: failed += 1 else: @@ -969,6 +1051,7 @@ async def recalculate_beatmap_rating( fetcher: Fetcher, redis: Redis, semaphore: asyncio.Semaphore, + cache_manager: BeatmapCacheManager | None = None, csv_writer: CSVWriter | None = None, ) -> None: """Recalculate difficulty rating for a beatmap.""" @@ -989,6 +1072,9 @@ async def recalculate_beatmap_rating( try: ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher) + # 记录使用的beatmap + if cache_manager: + await cache_manager.add_beatmap(beatmap_id) beatmap.difficulty_rating = attributes.star_rating break except CalculateError as exc: @@ -1070,6 +1156,14 @@ async def recalculate_performance( logger.info("No targets matched the provided filters; nothing to recalculate") return + # 创建缓存管理器 + cache_manager = BeatmapCacheManager( + max_count=global_config.max_cached_beatmaps_count, + additional_count=global_config.additional_count, + redis=redis, + ) + logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}") + scope = "full" if config.recalculate_all else "filtered" logger.info( "Recalculating performance for {} user/mode pairs ({}) | dry-run={} | concurrency={}", @@ -1083,12 +1177,15 @@ async def recalculate_performance( semaphore = asyncio.Semaphore(global_config.concurrency) coroutines = [ recalculate_user_mode_performance( - user_id, mode, score_ids, global_config, fetcher, redis, semaphore, csv_writer + user_id, mode, score_ids, global_config, fetcher, redis, semaphore, cache_manager, csv_writer ) for (user_id, mode), score_ids in targets.items() ] await run_in_batches(coroutines, global_config.concurrency) + # 显示最终统计 + logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}") + async def recalculate_leaderboard( config: LeaderboardConfig, @@ -1146,6 +1243,14 @@ async def recalculate_rating( logger.info("No beatmaps matched the provided filters; nothing to recalculate") return + # 创建缓存管理器 + cache_manager = BeatmapCacheManager( + max_count=global_config.max_cached_beatmaps_count, + additional_count=global_config.additional_count, + redis=redis, + ) + logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}") + scope = "full" if config.recalculate_all else "filtered" logger.info( "Recalculating rating for {} beatmaps ({}) | dry-run={} | concurrency={}", @@ -1158,11 +1263,14 @@ async def recalculate_rating( async with CSVWriter(global_config.output_csv) as csv_writer: semaphore = asyncio.Semaphore(global_config.concurrency) coroutines = [ - recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, csv_writer) + recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, cache_manager, csv_writer) for beatmap_id in beatmap_ids ] await run_in_batches(coroutines, global_config.concurrency) + # 显示最终统计 + logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}") + def _get_csv_path_for_subcommand(base_path: str | None, subcommand: str) -> str | None: """Generate a CSV path with subcommand name inserted before extension.""" @@ -1197,6 +1305,8 @@ async def main() -> None: dry_run=global_config.dry_run, concurrency=global_config.concurrency, output_csv=rating_csv_path, + max_cached_beatmaps_count=global_config.max_cached_beatmaps_count, + additional_count=global_config.additional_count, ) await recalculate_rating(rating_config, rating_global_config) @@ -1214,6 +1324,8 @@ async def main() -> None: dry_run=global_config.dry_run, concurrency=global_config.concurrency, output_csv=perf_csv_path, + max_cached_beatmaps_count=global_config.max_cached_beatmaps_count, + additional_count=global_config.additional_count, ) await recalculate_performance(perf_config, perf_global_config) @@ -1231,6 +1343,8 @@ async def main() -> None: dry_run=global_config.dry_run, concurrency=global_config.concurrency, output_csv=lead_csv_path, + max_cached_beatmaps_count=global_config.max_cached_beatmaps_count, + additional_count=global_config.additional_count, ) await recalculate_leaderboard(lead_config, lead_global_config)