feat(recalculate): add --additional-count & --max-cached-beatmaps-count to batch calculate
This commit is contained in:
@@ -257,3 +257,11 @@ async def calculate_beatmap_attributes(
|
|||||||
attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
|
attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
|
||||||
await redis.set(key, attr.model_dump_json())
|
await redis.set(key, attr.model_dump_json())
|
||||||
return attr
|
return attr
|
||||||
|
|
||||||
|
|
||||||
|
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
|
||||||
|
if beatmaps:
|
||||||
|
keys = [f"beatmap:{bid}:raw" for bid in beatmaps]
|
||||||
|
await redis.delete(*keys)
|
||||||
|
return
|
||||||
|
await redis.delete("beatmap:*:raw")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from app.calculators.performance import CalculateError
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.const import BANCHOBOT_ID
|
from app.const import BANCHOBOT_ID
|
||||||
from app.database import TotalScoreBestScore, UserStatistics
|
from app.database import TotalScoreBestScore, UserStatistics
|
||||||
from app.database.beatmap import Beatmap, calculate_beatmap_attributes
|
from app.database.beatmap import Beatmap, calculate_beatmap_attributes, clear_cached_beatmap_raws
|
||||||
from app.database.best_scores import BestScore
|
from app.database.best_scores import BestScore
|
||||||
from app.database.score import Score, calculate_playtime, calculate_user_pp
|
from app.database.score import Score, calculate_playtime, calculate_user_pp
|
||||||
from app.dependencies.database import engine, get_redis
|
from app.dependencies.database import engine, get_redis
|
||||||
@@ -40,11 +40,72 @@ logger = log("Recalculate")
|
|||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapCacheManager:
|
||||||
|
"""管理beatmap缓存,确保不超过指定数量"""
|
||||||
|
|
||||||
|
def __init__(self, max_count: int, additional_count: int, redis: Redis):
|
||||||
|
self.max_count = max_count
|
||||||
|
self.additional_count = additional_count
|
||||||
|
self.redis = redis
|
||||||
|
self.beatmap_ids: list[int] = [] # 记录处理的beatmap id(按顺序)
|
||||||
|
self.beatmap_id_set: set[int] = set() # 用于快速查找(唯一性)
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def add_beatmap(self, beatmap_id: int) -> None:
|
||||||
|
"""添加beatmap到缓存跟踪列表"""
|
||||||
|
if self.max_count <= 0: # 不限制
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self.lock:
|
||||||
|
# 如果已经存在,不重复添加
|
||||||
|
if beatmap_id in self.beatmap_id_set:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.beatmap_ids.append(beatmap_id)
|
||||||
|
self.beatmap_id_set.add(beatmap_id)
|
||||||
|
|
||||||
|
# 检查是否需要清理
|
||||||
|
threshold = self.max_count + max(0, self.additional_count)
|
||||||
|
if len(self.beatmap_ids) > threshold:
|
||||||
|
# 计算需要删除的数量
|
||||||
|
to_remove_count = max(1, self.additional_count)
|
||||||
|
await self._cleanup(to_remove_count)
|
||||||
|
|
||||||
|
async def _cleanup(self, count: int) -> None:
|
||||||
|
"""清理最早的count个beatmap缓存"""
|
||||||
|
if count <= 0 or not self.beatmap_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取要删除的beatmap ids
|
||||||
|
to_remove = self.beatmap_ids[:count]
|
||||||
|
self.beatmap_ids = self.beatmap_ids[count:]
|
||||||
|
|
||||||
|
# 从set中移除
|
||||||
|
for bid in to_remove:
|
||||||
|
self.beatmap_id_set.discard(bid)
|
||||||
|
|
||||||
|
# 从Redis中删除缓存
|
||||||
|
await clear_cached_beatmap_raws(self.redis, to_remove)
|
||||||
|
logger.info(f"Cleaned up {len(to_remove)} beatmap caches (total: {len(self.beatmap_ids)})")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""获取统计信息"""
|
||||||
|
threshold = self.max_count + max(0, self.additional_count) if self.max_count > 0 else "unlimited"
|
||||||
|
return {
|
||||||
|
"total_beatmaps": len(self.beatmap_ids),
|
||||||
|
"max_count": self.max_count,
|
||||||
|
"additional_count": self.additional_count,
|
||||||
|
"threshold": threshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GlobalConfig:
|
class GlobalConfig:
|
||||||
dry_run: bool
|
dry_run: bool
|
||||||
concurrency: int
|
concurrency: int
|
||||||
output_csv: str | None
|
output_csv: str | None
|
||||||
|
max_cached_beatmaps_count: int
|
||||||
|
additional_count: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -93,6 +154,20 @@ def parse_cli_args(
|
|||||||
type=str,
|
type=str,
|
||||||
help="Output results to a CSV file at the specified path",
|
help="Output results to a CSV file at the specified path",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-cached-beatmaps-count",
|
||||||
|
dest="max_cached_beatmaps_count",
|
||||||
|
type=int,
|
||||||
|
default=1500,
|
||||||
|
help="Maximum number of beatmaps to cache (<=0 means no limit)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--additional-count",
|
||||||
|
dest="additional_count",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of additional beatmaps before cleanup (<=0 means cleanup immediately)",
|
||||||
|
)
|
||||||
|
|
||||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||||
|
|
||||||
@@ -194,6 +269,8 @@ def parse_cli_args(
|
|||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
concurrency=max(1, args.concurrency),
|
concurrency=max(1, args.concurrency),
|
||||||
output_csv=args.output_csv,
|
output_csv=args.output_csv,
|
||||||
|
max_cached_beatmaps_count=args.max_cached_beatmaps_count,
|
||||||
|
additional_count=args.additional_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.command == "all":
|
if args.command == "all":
|
||||||
@@ -534,6 +611,7 @@ async def recalc_score_pp(
|
|||||||
fetcher: Fetcher,
|
fetcher: Fetcher,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
score: Score,
|
score: Score,
|
||||||
|
cache_manager: BeatmapCacheManager | None = None,
|
||||||
) -> float | None:
|
) -> float | None:
|
||||||
attempts = 10
|
attempts = 10
|
||||||
while attempts > 0:
|
while attempts > 0:
|
||||||
@@ -558,6 +636,9 @@ async def recalc_score_pp(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, score.beatmap_id)
|
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, score.beatmap_id)
|
||||||
|
# 记录使用的beatmap
|
||||||
|
if cache_manager:
|
||||||
|
await cache_manager.add_beatmap(score.beatmap_id)
|
||||||
new_pp = await calculate_pp(score, beatmap_raw, session)
|
new_pp = await calculate_pp(score, beatmap_raw, session)
|
||||||
score.pp = new_pp
|
score.pp = new_pp
|
||||||
return new_pp
|
return new_pp
|
||||||
@@ -728,6 +809,7 @@ async def recalculate_user_mode_performance(
|
|||||||
fetcher: Fetcher,
|
fetcher: Fetcher,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
semaphore: asyncio.Semaphore,
|
semaphore: asyncio.Semaphore,
|
||||||
|
cache_manager: BeatmapCacheManager | None = None,
|
||||||
csv_writer: CSVWriter | None = None,
|
csv_writer: CSVWriter | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Recalculate performance points and best scores (without TotalScoreBestScore)."""
|
"""Recalculate performance points and best scores (without TotalScoreBestScore)."""
|
||||||
@@ -767,7 +849,7 @@ async def recalculate_user_mode_performance(
|
|||||||
for score in passed_scores:
|
for score in passed_scores:
|
||||||
if target_set and score.id not in target_set:
|
if target_set and score.id not in target_set:
|
||||||
continue
|
continue
|
||||||
result_pp = await recalc_score_pp(session, fetcher, redis, score)
|
result_pp = await recalc_score_pp(session, fetcher, redis, score, cache_manager)
|
||||||
if result_pp is None:
|
if result_pp is None:
|
||||||
failed += 1
|
failed += 1
|
||||||
else:
|
else:
|
||||||
@@ -969,6 +1051,7 @@ async def recalculate_beatmap_rating(
|
|||||||
fetcher: Fetcher,
|
fetcher: Fetcher,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
semaphore: asyncio.Semaphore,
|
semaphore: asyncio.Semaphore,
|
||||||
|
cache_manager: BeatmapCacheManager | None = None,
|
||||||
csv_writer: CSVWriter | None = None,
|
csv_writer: CSVWriter | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Recalculate difficulty rating for a beatmap."""
|
"""Recalculate difficulty rating for a beatmap."""
|
||||||
@@ -989,6 +1072,9 @@ async def recalculate_beatmap_rating(
|
|||||||
try:
|
try:
|
||||||
ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
|
ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
|
||||||
attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher)
|
attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher)
|
||||||
|
# 记录使用的beatmap
|
||||||
|
if cache_manager:
|
||||||
|
await cache_manager.add_beatmap(beatmap_id)
|
||||||
beatmap.difficulty_rating = attributes.star_rating
|
beatmap.difficulty_rating = attributes.star_rating
|
||||||
break
|
break
|
||||||
except CalculateError as exc:
|
except CalculateError as exc:
|
||||||
@@ -1070,6 +1156,14 @@ async def recalculate_performance(
|
|||||||
logger.info("No targets matched the provided filters; nothing to recalculate")
|
logger.info("No targets matched the provided filters; nothing to recalculate")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 创建缓存管理器
|
||||||
|
cache_manager = BeatmapCacheManager(
|
||||||
|
max_count=global_config.max_cached_beatmaps_count,
|
||||||
|
additional_count=global_config.additional_count,
|
||||||
|
redis=redis,
|
||||||
|
)
|
||||||
|
logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}")
|
||||||
|
|
||||||
scope = "full" if config.recalculate_all else "filtered"
|
scope = "full" if config.recalculate_all else "filtered"
|
||||||
logger.info(
|
logger.info(
|
||||||
"Recalculating performance for {} user/mode pairs ({}) | dry-run={} | concurrency={}",
|
"Recalculating performance for {} user/mode pairs ({}) | dry-run={} | concurrency={}",
|
||||||
@@ -1083,12 +1177,15 @@ async def recalculate_performance(
|
|||||||
semaphore = asyncio.Semaphore(global_config.concurrency)
|
semaphore = asyncio.Semaphore(global_config.concurrency)
|
||||||
coroutines = [
|
coroutines = [
|
||||||
recalculate_user_mode_performance(
|
recalculate_user_mode_performance(
|
||||||
user_id, mode, score_ids, global_config, fetcher, redis, semaphore, csv_writer
|
user_id, mode, score_ids, global_config, fetcher, redis, semaphore, cache_manager, csv_writer
|
||||||
)
|
)
|
||||||
for (user_id, mode), score_ids in targets.items()
|
for (user_id, mode), score_ids in targets.items()
|
||||||
]
|
]
|
||||||
await run_in_batches(coroutines, global_config.concurrency)
|
await run_in_batches(coroutines, global_config.concurrency)
|
||||||
|
|
||||||
|
# 显示最终统计
|
||||||
|
logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}")
|
||||||
|
|
||||||
|
|
||||||
async def recalculate_leaderboard(
|
async def recalculate_leaderboard(
|
||||||
config: LeaderboardConfig,
|
config: LeaderboardConfig,
|
||||||
@@ -1146,6 +1243,14 @@ async def recalculate_rating(
|
|||||||
logger.info("No beatmaps matched the provided filters; nothing to recalculate")
|
logger.info("No beatmaps matched the provided filters; nothing to recalculate")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 创建缓存管理器
|
||||||
|
cache_manager = BeatmapCacheManager(
|
||||||
|
max_count=global_config.max_cached_beatmaps_count,
|
||||||
|
additional_count=global_config.additional_count,
|
||||||
|
redis=redis,
|
||||||
|
)
|
||||||
|
logger.info(f"Beatmap cache manager initialized: {cache_manager.get_stats()}")
|
||||||
|
|
||||||
scope = "full" if config.recalculate_all else "filtered"
|
scope = "full" if config.recalculate_all else "filtered"
|
||||||
logger.info(
|
logger.info(
|
||||||
"Recalculating rating for {} beatmaps ({}) | dry-run={} | concurrency={}",
|
"Recalculating rating for {} beatmaps ({}) | dry-run={} | concurrency={}",
|
||||||
@@ -1158,11 +1263,14 @@ async def recalculate_rating(
|
|||||||
async with CSVWriter(global_config.output_csv) as csv_writer:
|
async with CSVWriter(global_config.output_csv) as csv_writer:
|
||||||
semaphore = asyncio.Semaphore(global_config.concurrency)
|
semaphore = asyncio.Semaphore(global_config.concurrency)
|
||||||
coroutines = [
|
coroutines = [
|
||||||
recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, csv_writer)
|
recalculate_beatmap_rating(beatmap_id, global_config, fetcher, redis, semaphore, cache_manager, csv_writer)
|
||||||
for beatmap_id in beatmap_ids
|
for beatmap_id in beatmap_ids
|
||||||
]
|
]
|
||||||
await run_in_batches(coroutines, global_config.concurrency)
|
await run_in_batches(coroutines, global_config.concurrency)
|
||||||
|
|
||||||
|
# 显示最终统计
|
||||||
|
logger.info(f"Beatmap cache final stats: {cache_manager.get_stats()}")
|
||||||
|
|
||||||
|
|
||||||
def _get_csv_path_for_subcommand(base_path: str | None, subcommand: str) -> str | None:
|
def _get_csv_path_for_subcommand(base_path: str | None, subcommand: str) -> str | None:
|
||||||
"""Generate a CSV path with subcommand name inserted before extension."""
|
"""Generate a CSV path with subcommand name inserted before extension."""
|
||||||
@@ -1197,6 +1305,8 @@ async def main() -> None:
|
|||||||
dry_run=global_config.dry_run,
|
dry_run=global_config.dry_run,
|
||||||
concurrency=global_config.concurrency,
|
concurrency=global_config.concurrency,
|
||||||
output_csv=rating_csv_path,
|
output_csv=rating_csv_path,
|
||||||
|
max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
|
||||||
|
additional_count=global_config.additional_count,
|
||||||
)
|
)
|
||||||
await recalculate_rating(rating_config, rating_global_config)
|
await recalculate_rating(rating_config, rating_global_config)
|
||||||
|
|
||||||
@@ -1214,6 +1324,8 @@ async def main() -> None:
|
|||||||
dry_run=global_config.dry_run,
|
dry_run=global_config.dry_run,
|
||||||
concurrency=global_config.concurrency,
|
concurrency=global_config.concurrency,
|
||||||
output_csv=perf_csv_path,
|
output_csv=perf_csv_path,
|
||||||
|
max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
|
||||||
|
additional_count=global_config.additional_count,
|
||||||
)
|
)
|
||||||
await recalculate_performance(perf_config, perf_global_config)
|
await recalculate_performance(perf_config, perf_global_config)
|
||||||
|
|
||||||
@@ -1231,6 +1343,8 @@ async def main() -> None:
|
|||||||
dry_run=global_config.dry_run,
|
dry_run=global_config.dry_run,
|
||||||
concurrency=global_config.concurrency,
|
concurrency=global_config.concurrency,
|
||||||
output_csv=lead_csv_path,
|
output_csv=lead_csv_path,
|
||||||
|
max_cached_beatmaps_count=global_config.max_cached_beatmaps_count,
|
||||||
|
additional_count=global_config.additional_count,
|
||||||
)
|
)
|
||||||
await recalculate_leaderboard(lead_config, lead_global_config)
|
await recalculate_leaderboard(lead_config, lead_global_config)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user