feat(recalculate): add --additional-count & --max-cached-beatmaps-count to batch calculate

This commit is contained in:
MingxuanGame
2025-11-08 18:31:00 +00:00
parent 5c2687e1e4
commit a46b17fce4
2 changed files with 126 additions and 4 deletions

View File

@@ -257,3 +257,11 @@ async def calculate_beatmap_attributes(
attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset) attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
await redis.set(key, attr.model_dump_json()) await redis.set(key, attr.model_dump_json())
return attr return attr
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] | None = None) -> None:
    """Delete cached raw beatmap files (``beatmap:<id>:raw`` keys) from Redis.

    Args:
        redis: Async Redis connection used for the beatmap cache.
        beatmaps: Specific beatmap IDs whose raw caches should be dropped.
            When ``None`` or empty, every ``beatmap:*:raw`` key is removed.
    """
    # NOTE: default changed from a mutable `[]` (shared across calls) to None;
    # falsy handling below keeps `[]` callers behaving identically.
    if beatmaps:
        await redis.delete(*(f"beatmap:{bid}:raw" for bid in beatmaps))
        return
    # Fix: Redis DEL takes literal key names and does not expand glob patterns --
    # the original `delete("beatmap:*:raw")` removed only a key literally named
    # "beatmap:*:raw". SCAN for the matching keys and delete them explicitly.
    keys = [key async for key in redis.scan_iter(match="beatmap:*:raw")]
    if keys:
        await redis.delete(*keys)

View File

@@ -19,7 +19,7 @@ from app.calculators.performance import CalculateError
from app.config import settings from app.config import settings
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from app.database import TotalScoreBestScore, UserStatistics from app.database import TotalScoreBestScore, UserStatistics
from app.database.beatmap import Beatmap, calculate_beatmap_attributes from app.database.beatmap import Beatmap, calculate_beatmap_attributes, clear_cached_beatmap_raws
from app.database.best_scores import BestScore from app.database.best_scores import BestScore
from app.database.score import Score, calculate_playtime, calculate_user_pp from app.database.score import Score, calculate_playtime, calculate_user_pp
from app.dependencies.database import engine, get_redis from app.dependencies.database import engine, get_redis
@@ -40,11 +40,72 @@ logger = log("Recalculate")
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
class BeatmapCacheManager:
    """Track beatmaps cached in Redis and evict the oldest when over budget.

    Keeps an insertion-ordered record of beatmap IDs whose raw files were
    cached. Once the tracked count exceeds ``max_count`` plus the
    ``additional_count`` slack, the oldest entries are purged from Redis.
    """

    def __init__(self, max_count: int, additional_count: int, redis: Redis):
        self.max_count = max_count  # cache budget; <= 0 disables tracking entirely
        self.additional_count = additional_count  # slack allowed before a cleanup pass
        self.redis = redis
        self.beatmap_ids: list[int] = []  # tracked beatmap IDs, oldest first
        self.beatmap_id_set: set[int] = set()  # O(1) membership / uniqueness check
        self.lock = asyncio.Lock()  # serializes concurrent add_beatmap calls

    async def add_beatmap(self, beatmap_id: int) -> None:
        """Record that *beatmap_id* was cached; trigger eviction if over budget."""
        if self.max_count <= 0:  # limiting disabled
            return
        async with self.lock:
            if beatmap_id in self.beatmap_id_set:  # already tracked; nothing to do
                return
            self.beatmap_id_set.add(beatmap_id)
            self.beatmap_ids.append(beatmap_id)
            # Only clean up once we overshoot the budget by the slack amount,
            # so evictions happen in batches rather than one key at a time.
            limit = self.max_count + max(0, self.additional_count)
            if len(self.beatmap_ids) > limit:
                await self._cleanup(max(1, self.additional_count))

    async def _cleanup(self, count: int) -> None:
        """Drop the oldest *count* tracked beatmaps and their Redis caches."""
        if count <= 0 or not self.beatmap_ids:
            return
        to_remove, self.beatmap_ids = self.beatmap_ids[:count], self.beatmap_ids[count:]
        self.beatmap_id_set.difference_update(to_remove)
        await clear_cached_beatmap_raws(self.redis, to_remove)
        logger.info(f"Cleaned up {len(to_remove)} beatmap caches (total: {len(self.beatmap_ids)})")

    def get_stats(self) -> dict:
        """Return a snapshot of the tracker's configuration and current size."""
        if self.max_count > 0:
            threshold = self.max_count + max(0, self.additional_count)
        else:
            threshold = "unlimited"
        return {
            "total_beatmaps": len(self.beatmap_ids),
            "max_count": self.max_count,
            "additional_count": self.additional_count,
            "threshold": threshold,
        }
@dataclass(frozen=True) @dataclass(frozen=True)
class GlobalConfig: class GlobalConfig:
dry_run: bool dry_run: bool
concurrency: int concurrency: int
output_csv: str | None output_csv: str | None
max_cached_beatmaps_count: int
additional_count: int
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -93,6 +154,20 @@ def parse_cli_args(
type=str, type=str,
help="Output results to a CSV file at the specified path", help="Output results to a CSV file at the specified path",
) )
parser.add_argument(
"--max-cached-beatmaps-count",
dest="max_cached_beatmaps_count",
type=int,
default=1500,
help="Maximum number of beatmaps to cache (<=0 means no limit)",
)
parser.add_argument(
"--additional-count",
dest="additional_count",
type=int,
default=100,
help="Number of additional beatmaps before cleanup (<=0 means cleanup immediately)",
)
subparsers = parser.add_subparsers(dest="command", help="Available commands") subparsers = parser.add_subparsers(dest="command", help="Available commands")
@@ -194,6 +269,8 @@ def parse_cli_args(
dry_run=args.dry_run, dry_run=args.dry_run,
concurrency=max(1, args.concurrency), concurrency=max(1, args.concurrency),
output_csv=args.output_csv, output_csv=args.output_csv,
max_cached_beatmaps_count=args.max_cached_beatmaps_count,
additional_count=args.additional_count,
) )
if args.command == "all": if args.command == "all":
@@ -534,6 +611,7 @@ async def recalc_score_pp(
fetcher: Fetcher, fetcher: Fetcher,
redis: Redis, redis: Redis,
score: Score, score: Score,
cache_manager: BeatmapCacheManager | None = None,
) -> float | None: ) -> float | None:
attempts = 10 attempts = 10
while attempts > 0: while attempts > 0:
@@ -558,6 +636,9 @@ async def recalc_score_pp(
try: try:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, score.beatmap_id) beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, score.beatmap_id)
# 记录使用的beatmap
if cache_manager:
await cache_manager.add_beatmap(score.beatmap_id)
new_pp = await calculate_pp(score, beatmap_raw, session) new_pp = await calculate_pp(score, beatmap_raw, session)
score.pp = new_pp score.pp = new_pp
return new_pp return new_pp
@@ -728,6 +809,7 @@ async def recalculate_user_mode_performance(
fetcher: Fetcher, fetcher: Fetcher,
redis: Redis, redis: Redis,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
cache_manager: BeatmapCacheManager | None = None,
csv_writer: CSVWriter | None = None, csv_writer: CSVWriter | None = None,
) -> None: ) -> None:
"""Recalculate performance points and best scores (without TotalScoreBestScore).""" """Recalculate performance points and best scores (without TotalScoreBestScore)."""
@@ -767,7 +849,7 @@ async def recalculate_user_mode_performance(
for score in passed_scores: for score in passed_scores:
if target_set and score.id not in target_set: if target_set and score.id not in target_set:
continue continue
result_pp = await recalc_score_pp(session, fetcher, redis, score) result_pp = await recalc_score_pp(session, fetcher, redis, score, cache_manager)
if result_pp is None: if result_pp is None:
failed += 1 failed += 1
else: else:
@@ -969,6 +1051,7 @@ async def recalculate_beatmap_rating(
fetcher: Fetcher, fetcher: Fetcher,
redis: Redis, redis: Redis,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
cache_manager: BeatmapCacheManager | None = None,
csv_writer: CSVWriter | None = None, csv_writer: CSVWriter | None = None,
) -> None: ) -> None:
"""Recalculate difficulty rating for a beatmap.""" """Recalculate difficulty rating for a beatmap."""
@@ -989,6 +1072,9 @@ async def recalculate_beatmap_rating(
try: try:
ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher) attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher)
# 记录使用的beatmap
if cache_manager:
await cache_manager.add_beatmap(beatmap_id)
beatmap.difficulty_rating = attributes.star_rating beatmap.difficulty_rating = attributes.star_rating
break break
except CalculateError as exc: except CalculateError as exc:
@@ -1070,6 +1156,14 @@ async def recalculate_performance(
logger.info("No targets matched the provided filters; nothing to recalculate") logger.info("No targets matched the provided filters; nothing to recalculate")
return return
# 创建缓存管理器
cache_manager = BeatmapCacheManager(
max_count=global_config.max_cached_beatmaps_count,
additional_count=global_config.additional_count,
redis=redis,
)
logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}")
scope = "full" if config.recalculate_all else "filtered" scope = "full" if config.recalculate_all else "filtered"
logger.info( logger.info(
"Recalculating performance for {} user/mode pairs ({}) | dry-run={} | concurrency={}", "Recalculating performance for {} user/mode pairs ({}) | dry-run={} | concurrency={}",
@@ -1083,12 +1177,15 @@ async def recalculate_performance(
semaphore = asyncio.Semaphore(global_config.concurrency) semaphore = asyncio.Semaphore(global_config.concurrency)
coroutines = [ coroutines = [
recalculate_user_mode_performance( recalculate_user_mode_performance(
user_id, mode, score_ids, global_config, fetcher, redis, semaphore, csv_writer user_id, mode, score_ids, global_config, fetcher, redis, semaphore, cache_manager, csv_writer
) )
for (user_id, mode), score_ids in targets.items() for (user_id, mode), score_ids in targets.items()
] ]
await run_in_batches(coroutines, global_config.concurrency) await run_in_batches(coroutines, global_config.concurrency)
# 显示最终统计
logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}")
async def recalculate_leaderboard( async def recalculate_leaderboard(
config: LeaderboardConfig, config: LeaderboardConfig,
@@ -1146,6 +1243,14 @@ async def recalculate_rating(
logger.info("No beatmaps matched the provided filters; nothing to recalculate") logger.info("No beatmaps matched the provided filters; nothing to recalculate")
return return
# 创建缓存管理器
cache_manager = BeatmapCacheManager(
max_count=global_config.max_cached_beatmaps_count,
additional_count=global_config.additional_count,
redis=redis,
)
logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}")
scope = "full" if config.recalculate_all else "filtered" scope = "full" if config.recalculate_all else "filtered"
logger.info( logger.info(
"Recalculating rating for {} beatmaps ({}) | dry-run={} | concurrency={}", "Recalculating rating for {} beatmaps ({}) | dry-run={} | concurrency={}",
@@ -1158,11 +1263,14 @@ async def recalculate_rating(
async with CSVWriter(global_config.output_csv) as csv_writer: async with CSVWriter(global_config.output_csv) as csv_writer:
semaphore = asyncio.Semaphore(global_config.concurrency) semaphore = asyncio.Semaphore(global_config.concurrency)
coroutines = [ coroutines = [
recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, csv_writer) recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, cache_manager, csv_writer)
for beatmap_id in beatmap_ids for beatmap_id in beatmap_ids
] ]
await run_in_batches(coroutines, global_config.concurrency) await run_in_batches(coroutines, global_config.concurrency)
# 显示最终统计
logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}")
def _get_csv_path_for_subcommand(base_path: str | None, subcommand: str) -> str | None: def _get_csv_path_for_subcommand(base_path: str | None, subcommand: str) -> str | None:
"""Generate a CSV path with subcommand name inserted before extension.""" """Generate a CSV path with subcommand name inserted before extension."""
@@ -1197,6 +1305,8 @@ async def main() -> None:
dry_run=global_config.dry_run, dry_run=global_config.dry_run,
concurrency=global_config.concurrency, concurrency=global_config.concurrency,
output_csv=rating_csv_path, output_csv=rating_csv_path,
max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
additional_count=global_config.additional_count,
) )
await recalculate_rating(rating_config, rating_global_config) await recalculate_rating(rating_config, rating_global_config)
@@ -1214,6 +1324,8 @@ async def main() -> None:
dry_run=global_config.dry_run, dry_run=global_config.dry_run,
concurrency=global_config.concurrency, concurrency=global_config.concurrency,
output_csv=perf_csv_path, output_csv=perf_csv_path,
max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
additional_count=global_config.additional_count,
) )
await recalculate_performance(perf_config, perf_global_config) await recalculate_performance(perf_config, perf_global_config)
@@ -1231,6 +1343,8 @@ async def main() -> None:
dry_run=global_config.dry_run, dry_run=global_config.dry_run,
concurrency=global_config.concurrency, concurrency=global_config.concurrency,
output_csv=lead_csv_path, output_csv=lead_csv_path,
max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
additional_count=global_config.additional_count,
) )
await recalculate_leaderboard(lead_config, lead_global_config) await recalculate_leaderboard(lead_config, lead_global_config)