From 46b60e555f5119d9b0acaed31bee9bceb7c4fa96 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 25 Aug 2025 14:03:10 +0000 Subject: [PATCH] refactor(recalculate): optimize batch processing and use semaphore --- tools/recalculate.py | 383 ++++++++++++++++++++++--------------------- 1 file changed, 199 insertions(+), 184 deletions(-) diff --git a/tools/recalculate.py b/tools/recalculate.py index bd897e8..ddbe4e7 100644 --- a/tools/recalculate.py +++ b/tools/recalculate.py @@ -33,6 +33,13 @@ from sqlalchemy.orm import joinedload from sqlmodel import col, delete, select from sqlmodel.ext.asyncio.session import AsyncSession +SEMAPHORE = asyncio.Semaphore(50) + + +async def run_in_batches(coros, batch_size=200): + for i in range(0, len(coros), batch_size): + await asyncio.gather(*coros[i : i + batch_size]) + async def recalculate(): async with AsyncSession(engine, autoflush=False) as session: @@ -51,20 +58,26 @@ async def recalculate(): ) ) ).all() - await asyncio.gather( - *[ + await run_in_batches( + [ _recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis) for statistics in statistics_list - ] + ], + batch_size=200, ) - await asyncio.gather( - *[ + await run_in_batches( + [ _recalculate_best_score(statistics.user_id, statistics.mode, session) for statistics in statistics_list - ] + ], + batch_size=200, ) await session.commit() - await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list]) + for statistics in statistics_list: + await session.refresh(statistics) + await run_in_batches( + [_recalculate_statistics(statistics, session) for statistics in statistics_list], batch_size=200 + ) await session.commit() logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}") @@ -78,68 +91,69 @@ async def _recalculate_pp( fetcher: Fetcher, redis: Redis, ): - scores = ( - await session.exec( - select(Score).where( - Score.user_id == user_id, - Score.gamemode == gamemode, - col(Score.passed).is_(True), + async with SEMAPHORE: + scores = ( + await session.exec( + select(Score).where( + Score.user_id == user_id, + Score.gamemode == gamemode, + col(Score.passed).is_(True), + ) ) - ) - ).all() - prev: dict[int, PPBestScore] = {} + ).all() + prev: dict[int, PPBestScore] = {} - async def cal(score: Score): - time = 10 - beatmap_id = score.beatmap_id - while time > 0: - try: - db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id) - except HTTPError: - time -= 1 - await asyncio.sleep(2) - continue - ranked = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp - if not ranked or not mods_can_get_pp(int(score.gamemode), score.mods): - score.pp = 0 - return - try: - beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) - pp = await calculate_pp(score, beatmap_raw, session) - if pp == 0: + async def cal(score: Score): + time = 10 + beatmap_id = score.beatmap_id + while time > 0: + try: + db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id) + except HTTPError: + time -= 1 + await asyncio.sleep(2) + continue + ranked = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp + if not ranked or not mods_can_get_pp(int(score.gamemode), score.mods): + score.pp = 0 return - if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp: - best_score = PPBestScore( - user_id=user_id, - beatmap_id=beatmap_id, - acc=score.accuracy, - score_id=score.id, - pp=pp, - gamemode=score.gamemode, - ) - prev[score.beatmap_id] = best_score - return - except HTTPError: - time -= 1 - await asyncio.sleep(2) - continue - except Exception: - logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}") - return - if time <= 0: - logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...") - return score + try: + beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + pp = await calculate_pp(score, beatmap_raw, session) + if pp == 0: + return + if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp: + best_score = PPBestScore( + user_id=user_id, + beatmap_id=beatmap_id, + acc=score.accuracy, + score_id=score.id, + pp=pp, + gamemode=score.gamemode, + ) + prev[score.beatmap_id] = best_score + return + except HTTPError: + time -= 1 + await asyncio.sleep(2) + continue + except Exception: + logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}") + return + if time <= 0: + logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...") + return score - while len(scores) > 0: - results = await asyncio.gather(*[cal(s) for s in scores]) - scores = [s for s in results if s is not None] - if len(scores) == 0: - break - await asyncio.sleep(30) - logger.info(f"Retry to calculate for {gamemode}, total: {len(scores)}") + while len(scores) > 0: + results = await asyncio.gather(*[cal(s) for s in scores]) + scores = [s for s in results if s is not None] + if len(scores) == 0: + break + await asyncio.sleep(30) + logger.info(f"Retry to calculate for {gamemode}, total: {len(scores)}") - for best_score in prev.values(): - session.add(best_score) + for best_score in prev.values(): + session.add(best_score) async def _recalculate_best_score( @@ -147,144 +161,145 @@ async def _recalculate_best_score( gamemode: GameMode, session: AsyncSession, ): - beatmap_best_score: dict[int, list[BestScore]] = {} - scores = ( - await session.exec( - select(Score).where( - Score.gamemode == gamemode, - col(Score.passed).is_(True), - Score.user_id == user_id, + async with SEMAPHORE: + beatmap_best_score: dict[int, list[BestScore]] = {} + scores = ( + await session.exec( + select(Score).where( + Score.gamemode == gamemode, + col(Score.passed).is_(True), + Score.user_id == user_id, + ) ) - ) - ).all() - for score in scores: - if not ( - (await score.awaitable_attrs.beatmap).beatmap_status.has_leaderboard() - | settings.enable_all_beatmap_leaderboard - ): - continue - mod_for_save = mod_to_save(score.mods) - bs = BestScore( - user_id=score.user_id, - score_id=score.id, - beatmap_id=score.beatmap_id, - gamemode=score.gamemode, - total_score=score.total_score, - mods=mod_for_save, - rank=score.rank, - ) - if score.beatmap_id not in beatmap_best_score: - beatmap_best_score[score.beatmap_id] = [bs] - else: - b = next( - ( - s - for s in beatmap_best_score[score.beatmap_id] - if s.mods == mod_for_save and s.beatmap_id == score.beatmap_id - ), - None, + ).all() + for score in scores: + if not ( + (await score.awaitable_attrs.beatmap).beatmap_status.has_leaderboard() + | settings.enable_all_beatmap_leaderboard + ): + continue + mod_for_save = mod_to_save(score.mods) + bs = BestScore( + user_id=score.user_id, + score_id=score.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + total_score=score.total_score, + mods=mod_for_save, + rank=score.rank, ) - if b is None: - beatmap_best_score[score.beatmap_id].append(bs) - elif score.total_score > b.total_score: - beatmap_best_score[score.beatmap_id].remove(b) - beatmap_best_score[score.beatmap_id].append(bs) + if score.beatmap_id not in beatmap_best_score: + beatmap_best_score[score.beatmap_id] = [bs] + else: + b = next( + ( + s + for s in beatmap_best_score[score.beatmap_id] + if s.mods == mod_for_save and s.beatmap_id == score.beatmap_id + ), + None, + ) + if b is None: + beatmap_best_score[score.beatmap_id].append(bs) + elif score.total_score > b.total_score: + beatmap_best_score[score.beatmap_id].remove(b) + beatmap_best_score[score.beatmap_id].append(bs) - for best_score_in_beatmap in beatmap_best_score.values(): - for score in best_score_in_beatmap: - session.add(score) + for best_score_in_beatmap in beatmap_best_score.values(): + for score in best_score_in_beatmap: + session.add(score) async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession): - await session.refresh(statistics) - pp_sum = 0 - acc_sum = 0 - bps = await get_user_best_pp(session, statistics.user_id, statistics.mode) - for i, s in enumerate(bps): - pp_sum += calculate_weighted_pp(s.pp, i) - acc_sum += calculate_weighted_acc(s.acc, i) - if len(bps): - # https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45 - acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(bps)))) - acc_sum = clamp(acc_sum, 0.0, 100.0) - statistics.pp = pp_sum - statistics.hit_accuracy = acc_sum + async with SEMAPHORE: + pp_sum = 0 + acc_sum = 0 + bps = await get_user_best_pp(session, statistics.user_id, statistics.mode) + for i, s in enumerate(bps): + pp_sum += calculate_weighted_pp(s.pp, i) + acc_sum += calculate_weighted_acc(s.acc, i) + if len(bps): + # https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45 + acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(bps)))) + acc_sum = clamp(acc_sum, 0.0, 100.0) + statistics.pp = pp_sum + statistics.hit_accuracy = acc_sum - statistics.play_count = 0 - statistics.total_score = 0 - statistics.maximum_combo = 0 - statistics.play_time = 0 - statistics.total_hits = 0 - statistics.count_100 = 0 - statistics.count_300 = 0 - statistics.count_50 = 0 - statistics.count_miss = 0 - statistics.ranked_score = 0 - statistics.grade_ss = 0 - statistics.grade_ssh = 0 - statistics.grade_s = 0 - statistics.grade_sh = 0 - statistics.grade_a = 0 + statistics.play_count = 0 + statistics.total_score = 0 + statistics.maximum_combo = 0 + statistics.play_time = 0 + statistics.total_hits = 0 + statistics.count_100 = 0 + statistics.count_300 = 0 + statistics.count_50 = 0 + statistics.count_miss = 0 + statistics.ranked_score = 0 + statistics.grade_ss = 0 + statistics.grade_ssh = 0 + statistics.grade_s = 0 + statistics.grade_sh = 0 + statistics.grade_a = 0 - scores = ( - await session.exec( - select(Score) - .where( - Score.user_id == statistics.user_id, - Score.gamemode == statistics.mode, + scores = ( + await session.exec( + select(Score) + .where( + Score.user_id == statistics.user_id, + Score.gamemode == statistics.mode, + ) + .options(joinedload(Score.beatmap)) ) - .options(joinedload(Score.beatmap)) - ) - ).all() + ).all() - cached_beatmap_best: dict[int, Score] = {} + cached_beatmap_best: dict[int, Score] = {} - for score in scores: - beatmap: Beatmap = score.beatmap - ranked = beatmap.beatmap_status.has_pp() | settings.enable_all_mods_pp + for score in scores: + beatmap: Beatmap = score.beatmap + ranked = beatmap.beatmap_status.has_pp() | settings.enable_all_mods_pp - statistics.play_count += 1 - statistics.total_score += score.total_score - playtime, is_valid = calculate_playtime(score, beatmap.hit_length) - if is_valid: - statistics.play_time += playtime - statistics.count_300 += score.n300 + score.ngeki - statistics.count_100 += score.n100 + score.nkatu - statistics.count_50 += score.n50 - statistics.count_miss += score.nmiss - statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50 + statistics.play_count += 1 + statistics.total_score += score.total_score + playtime, is_valid = calculate_playtime(score, beatmap.hit_length) + if is_valid: + statistics.play_time += playtime + statistics.count_300 += score.n300 + score.ngeki + statistics.count_100 += score.n100 + score.nkatu + statistics.count_50 += score.n50 + statistics.count_miss += score.nmiss + statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50 - if ranked and score.passed: - statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) - previous = cached_beatmap_best.get(score.beatmap_id) - difference = score.total_score - (previous.total_score if previous else 0) - if difference > 0: - cached_beatmap_best[score.beatmap_id] = score - statistics.ranked_score += difference - match score.rank: - case Rank.X: - statistics.grade_ss += 1 - case Rank.XH: - statistics.grade_ssh += 1 - case Rank.S: - statistics.grade_s += 1 - case Rank.SH: - statistics.grade_sh += 1 - case Rank.A: - statistics.grade_a += 1 - if previous is not None: - match previous.rank: + if ranked and score.passed: + statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) + previous = cached_beatmap_best.get(score.beatmap_id) + difference = score.total_score - (previous.total_score if previous else 0) + if difference > 0: + cached_beatmap_best[score.beatmap_id] = score + statistics.ranked_score += difference + match score.rank: case Rank.X: - statistics.grade_ss -= 1 + statistics.grade_ss += 1 case Rank.XH: - statistics.grade_ssh -= 1 + statistics.grade_ssh += 1 case Rank.S: - statistics.grade_s -= 1 + statistics.grade_s += 1 case Rank.SH: - statistics.grade_sh -= 1 + statistics.grade_sh += 1 case Rank.A: - statistics.grade_a -= 1 - statistics.level_current = calculate_score_to_level(statistics.total_score) + statistics.grade_a += 1 + if previous is not None: + match previous.rank: + case Rank.X: + statistics.grade_ss -= 1 + case Rank.XH: + statistics.grade_ssh -= 1 + case Rank.S: + statistics.grade_s -= 1 + case Rank.SH: + statistics.grade_sh -= 1 + case Rank.A: + statistics.grade_a -= 1 + statistics.level_current = calculate_score_to_level(statistics.total_score) if __name__ == "__main__":