feat(recalculate): add --additional-count & --max-cached-beatmaps-count to batch calculate
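The two new flags bound how many raw beatmap files the batch recalculation keeps cached in Redis while it runs: --max-cached-beatmaps-count (default 1500; <=0 disables the limit) sets the target cache size, and --additional-count (default 100; <=0 means clean up immediately) is the slack allowed past that limit before the oldest cached entries are evicted. With the defaults, cleanup triggers once more than 1500 + 100 = 1600 distinct beatmaps have been tracked, and the 100 oldest cached raws are deleted. A possible invocation (the script entry point name here is an assumption; it is not shown in this diff):

    python recalculate.py --max-cached-beatmaps-count 2000 --additional-count 200 all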
@@ -257,3 +257,11 @@ async def calculate_beatmap_attributes(
    attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
    await redis.set(key, attr.model_dump_json())
    return attr


async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
    if beatmaps:
        keys = [f"beatmap:{bid}:raw" for bid in beatmaps]
        await redis.delete(*keys)
        return
    await redis.delete("beatmap:*:raw")
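When called with explicit ids, clear_cached_beatmap_raws deletes the corresponding beatmap:{id}:raw keys; note that the no-argument fallback passes the literal string "beatmap:*:raw" to DEL, which Redis treats as a single key name rather than a glob pattern. A minimal sketch of the targeted branch (the connection settings below are illustrative, not taken from this commit):

    import asyncio
    from redis.asyncio import Redis

    from app.database.beatmap import clear_cached_beatmap_raws

    async def demo() -> None:
        redis = Redis(host="localhost", port=6379)  # illustrative connection settings
        # removes beatmap:123:raw and beatmap:456:raw from the cache
        await clear_cached_beatmap_raws(redis, [123, 456])
        await redis.aclose()  # close() on older redis-py releases

    asyncio.run(demo())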
@@ -19,7 +19,7 @@ from app.calculators.performance import CalculateError
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import TotalScoreBestScore, UserStatistics
from app.database.beatmap import Beatmap, calculate_beatmap_attributes
from app.database.beatmap import Beatmap, calculate_beatmap_attributes, clear_cached_beatmap_raws
from app.database.best_scores import BestScore
from app.database.score import Score, calculate_playtime, calculate_user_pp
from app.dependencies.database import engine, get_redis
@@ -40,11 +40,72 @@ logger = log("Recalculate")
warnings.filterwarnings("ignore")


class BeatmapCacheManager:
    """Tracks cached beatmaps and keeps the cache below the configured size."""

    def __init__(self, max_count: int, additional_count: int, redis: Redis):
        self.max_count = max_count
        self.additional_count = additional_count
        self.redis = redis
        self.beatmap_ids: list[int] = []  # processed beatmap ids, in insertion order
        self.beatmap_id_set: set[int] = set()  # for fast membership checks (uniqueness)
        self.lock = asyncio.Lock()

    async def add_beatmap(self, beatmap_id: int) -> None:
        """Add a beatmap to the cache tracking list."""
        if self.max_count <= 0:  # no limit
            return

        async with self.lock:
            # skip ids that are already tracked
            if beatmap_id in self.beatmap_id_set:
                return

            self.beatmap_ids.append(beatmap_id)
            self.beatmap_id_set.add(beatmap_id)

            # check whether a cleanup is due
            threshold = self.max_count + max(0, self.additional_count)
            if len(self.beatmap_ids) > threshold:
                # number of entries to evict
                to_remove_count = max(1, self.additional_count)
                await self._cleanup(to_remove_count)

    async def _cleanup(self, count: int) -> None:
        """Evict the oldest `count` beatmap caches."""
        if count <= 0 or not self.beatmap_ids:
            return

        # beatmap ids to evict
        to_remove = self.beatmap_ids[:count]
        self.beatmap_ids = self.beatmap_ids[count:]

        # drop them from the lookup set
        for bid in to_remove:
            self.beatmap_id_set.discard(bid)

        # delete the cached raws from Redis
        await clear_cached_beatmap_raws(self.redis, to_remove)
        logger.info(f"Cleaned up {len(to_remove)} beatmap caches (total: {len(self.beatmap_ids)})")

    def get_stats(self) -> dict:
        """Return cache statistics."""
        threshold = self.max_count + max(0, self.additional_count) if self.max_count > 0 else "unlimited"
        return {
            "total_beatmaps": len(self.beatmap_ids),
            "max_count": self.max_count,
            "additional_count": self.additional_count,
            "threshold": threshold,
        }
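A minimal sketch of the eviction behaviour with small limits, assuming BeatmapCacheManager is in scope (it is defined in the recalculation script itself); fakeredis and the numbers are purely illustrative, whereas in the script the manager is built from the CLI flags and the shared Redis connection:

    import asyncio
    from fakeredis.aioredis import FakeRedis  # stand-in for the real Redis connection (illustrative only)

    async def demo() -> None:
        redis = FakeRedis()
        manager = BeatmapCacheManager(max_count=3, additional_count=2, redis=redis)
        for beatmap_id in range(1, 7):
            await redis.set(f"beatmap:{beatmap_id}:raw", b"raw file bytes")
            await manager.add_beatmap(beatmap_id)
        # threshold is 3 + 2 = 5, so tracking the 6th id evicts the 2 oldest cached raws
        print(manager.get_stats())  # total_beatmaps == 4 after the cleanup

    asyncio.run(demo())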


@dataclass(frozen=True)
class GlobalConfig:
    dry_run: bool
    concurrency: int
    output_csv: str | None
    max_cached_beatmaps_count: int
    additional_count: int


@dataclass(frozen=True)
@@ -93,6 +154,20 @@ def parse_cli_args(
        type=str,
        help="Output results to a CSV file at the specified path",
    )
    parser.add_argument(
        "--max-cached-beatmaps-count",
        dest="max_cached_beatmaps_count",
        type=int,
        default=1500,
        help="Maximum number of beatmaps to cache (<=0 means no limit)",
    )
    parser.add_argument(
        "--additional-count",
        dest="additional_count",
        type=int,
        default=100,
        help="Number of additional beatmaps before cleanup (<=0 means cleanup immediately)",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

@@ -194,6 +269,8 @@ def parse_cli_args(
        dry_run=args.dry_run,
        concurrency=max(1, args.concurrency),
        output_csv=args.output_csv,
        max_cached_beatmaps_count=args.max_cached_beatmaps_count,
        additional_count=args.additional_count,
    )

    if args.command == "all":
@@ -534,6 +611,7 @@ async def recalc_score_pp(
    fetcher: Fetcher,
    redis: Redis,
    score: Score,
    cache_manager: BeatmapCacheManager | None = None,
) -> float | None:
    attempts = 10
    while attempts > 0:

@@ -558,6 +636,9 @@ async def recalc_score_pp(

        try:
            beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, score.beatmap_id)
            # track the beatmap whose raw file was just fetched/cached
            if cache_manager:
                await cache_manager.add_beatmap(score.beatmap_id)
            new_pp = await calculate_pp(score, beatmap_raw, session)
            score.pp = new_pp
            return new_pp
@@ -728,6 +809,7 @@ async def recalculate_user_mode_performance(
    fetcher: Fetcher,
    redis: Redis,
    semaphore: asyncio.Semaphore,
    cache_manager: BeatmapCacheManager | None = None,
    csv_writer: CSVWriter | None = None,
) -> None:
    """Recalculate performance points and best scores (without TotalScoreBestScore)."""

@@ -767,7 +849,7 @@ async def recalculate_user_mode_performance(
        for score in passed_scores:
            if target_set and score.id not in target_set:
                continue
            result_pp = await recalc_score_pp(session, fetcher, redis, score)
            result_pp = await recalc_score_pp(session, fetcher, redis, score, cache_manager)
            if result_pp is None:
                failed += 1
            else:
@@ -969,6 +1051,7 @@ async def recalculate_beatmap_rating(
    fetcher: Fetcher,
    redis: Redis,
    semaphore: asyncio.Semaphore,
    cache_manager: BeatmapCacheManager | None = None,
    csv_writer: CSVWriter | None = None,
) -> None:
    """Recalculate difficulty rating for a beatmap."""

@@ -989,6 +1072,9 @@ async def recalculate_beatmap_rating(
        try:
            ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
            attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher)
            # track the beatmap that was just recalculated
            if cache_manager:
                await cache_manager.add_beatmap(beatmap_id)
            beatmap.difficulty_rating = attributes.star_rating
            break
        except CalculateError as exc:
@@ -1070,6 +1156,14 @@ async def recalculate_performance(
        logger.info("No targets matched the provided filters; nothing to recalculate")
        return

    # create the cache manager
    cache_manager = BeatmapCacheManager(
        max_count=global_config.max_cached_beatmaps_count,
        additional_count=global_config.additional_count,
        redis=redis,
    )
    logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}")

    scope = "full" if config.recalculate_all else "filtered"
    logger.info(
        "Recalculating performance for {} user/mode pairs ({}) | dry-run={} | concurrency={}",

@@ -1083,12 +1177,15 @@ async def recalculate_performance(
        semaphore = asyncio.Semaphore(global_config.concurrency)
        coroutines = [
            recalculate_user_mode_performance(
                user_id, mode, score_ids, global_config, fetcher, redis, semaphore, csv_writer
                user_id, mode, score_ids, global_config, fetcher, redis, semaphore, cache_manager, csv_writer
            )
            for (user_id, mode), score_ids in targets.items()
        ]
        await run_in_batches(coroutines, global_config.concurrency)

    # log the final cache statistics
    logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}")


async def recalculate_leaderboard(
    config: LeaderboardConfig,
@@ -1146,6 +1243,14 @@ async def recalculate_rating(
        logger.info("No beatmaps matched the provided filters; nothing to recalculate")
        return

    # create the cache manager
    cache_manager = BeatmapCacheManager(
        max_count=global_config.max_cached_beatmaps_count,
        additional_count=global_config.additional_count,
        redis=redis,
    )
    logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}")

    scope = "full" if config.recalculate_all else "filtered"
    logger.info(
        "Recalculating rating for {} beatmaps ({}) | dry-run={} | concurrency={}",

@@ -1158,11 +1263,14 @@ async def recalculate_rating(
    async with CSVWriter(global_config.output_csv) as csv_writer:
        semaphore = asyncio.Semaphore(global_config.concurrency)
        coroutines = [
            recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, csv_writer)
            recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, cache_manager, csv_writer)
            for beatmap_id in beatmap_ids
        ]
        await run_in_batches(coroutines, global_config.concurrency)

    # log the final cache statistics
    logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}")


def _get_csv_path_for_subcommand(base_path: str | None, subcommand: str) -> str | None:
    """Generate a CSV path with subcommand name inserted before extension."""
@@ -1197,6 +1305,8 @@ async def main() -> None:
        dry_run=global_config.dry_run,
        concurrency=global_config.concurrency,
        output_csv=rating_csv_path,
        max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
        additional_count=global_config.additional_count,
    )
    await recalculate_rating(rating_config, rating_global_config)

@@ -1214,6 +1324,8 @@ async def main() -> None:
        dry_run=global_config.dry_run,
        concurrency=global_config.concurrency,
        output_csv=perf_csv_path,
        max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
        additional_count=global_config.additional_count,
    )
    await recalculate_performance(perf_config, perf_global_config)

@@ -1231,6 +1343,8 @@ async def main() -> None:
        dry_run=global_config.dry_run,
        concurrency=global_config.concurrency,
        output_csv=lead_csv_path,
        max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
        additional_count=global_config.additional_count,
    )
    await recalculate_leaderboard(lead_config, lead_global_config)