refactor(recalculate): optimize batch processing and use semaphore

This commit is contained in:
MingxuanGame
2025-08-25 14:03:10 +00:00
parent 900fa9b121
commit 46b60e555f

View File

@@ -33,6 +33,13 @@ from sqlalchemy.orm import joinedload
from sqlmodel import col, delete, select from sqlmodel import col, delete, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
SEMAPHORE = asyncio.Semaphore(50)
async def run_in_batches(coros, batch_size=200):
for i in range(0, len(coros), batch_size):
await asyncio.gather(*coros[i : i + batch_size])
async def recalculate(): async def recalculate():
async with AsyncSession(engine, autoflush=False) as session: async with AsyncSession(engine, autoflush=False) as session:
@@ -51,20 +58,26 @@ async def recalculate():
) )
) )
).all() ).all()
await asyncio.gather( await run_in_batches(
*[ [
_recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis) _recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis)
for statistics in statistics_list for statistics in statistics_list
] ],
batch_size=200,
) )
await asyncio.gather( await run_in_batches(
*[ [
_recalculate_best_score(statistics.user_id, statistics.mode, session) _recalculate_best_score(statistics.user_id, statistics.mode, session)
for statistics in statistics_list for statistics in statistics_list
] ],
batch_size=200,
) )
await session.commit() await session.commit()
await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list]) for statistics in statistics_list:
await session.refresh(statistics)
await run_in_batches(
[_recalculate_statistics(statistics, session) for statistics in statistics_list], batch_size=200
)
await session.commit() await session.commit()
logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}") logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}")
@@ -78,68 +91,69 @@ async def _recalculate_pp(
fetcher: Fetcher, fetcher: Fetcher,
redis: Redis, redis: Redis,
): ):
scores = ( async with SEMAPHORE:
await session.exec( scores = (
select(Score).where( await session.exec(
Score.user_id == user_id, select(Score).where(
Score.gamemode == gamemode, Score.user_id == user_id,
col(Score.passed).is_(True), Score.gamemode == gamemode,
col(Score.passed).is_(True),
)
) )
) ).all()
).all() prev: dict[int, PPBestScore] = {}
prev: dict[int, PPBestScore] = {}
async def cal(score: Score): async def cal(score: Score):
time = 10 time = 10
beatmap_id = score.beatmap_id beatmap_id = score.beatmap_id
while time > 0: while time > 0:
try: try:
db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id) db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id)
except HTTPError: except HTTPError:
time -= 1 time -= 1
await asyncio.sleep(2) await asyncio.sleep(2)
continue continue
ranked = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp ranked = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
if not ranked or not mods_can_get_pp(int(score.gamemode), score.mods): if not ranked or not mods_can_get_pp(int(score.gamemode), score.mods):
score.pp = 0 score.pp = 0
return
try:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await calculate_pp(score, beatmap_raw, session)
if pp == 0:
return return
if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp: try:
best_score = PPBestScore( beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
user_id=user_id, pp = await calculate_pp(score, beatmap_raw, session)
beatmap_id=beatmap_id, if pp == 0:
acc=score.accuracy, return
score_id=score.id, if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp:
pp=pp, best_score = PPBestScore(
gamemode=score.gamemode, user_id=user_id,
) beatmap_id=beatmap_id,
prev[score.beatmap_id] = best_score acc=score.accuracy,
return score_id=score.id,
except HTTPError: pp=pp,
time -= 1 gamemode=score.gamemode,
await asyncio.sleep(2) )
continue prev[score.beatmap_id] = best_score
except Exception: return
logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}") except HTTPError:
return time -= 1
if time <= 0: await asyncio.sleep(2)
logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...") continue
return score except Exception:
logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}")
return
if time <= 0:
logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...")
return score
while len(scores) > 0: while len(scores) > 0:
results = await asyncio.gather(*[cal(s) for s in scores]) results = await asyncio.gather(*[cal(s) for s in scores])
scores = [s for s in results if s is not None] scores = [s for s in results if s is not None]
if len(scores) == 0: if len(scores) == 0:
break break
await asyncio.sleep(30) await asyncio.sleep(30)
logger.info(f"Retry to calculate for {gamemode}, total: {len(scores)}") logger.info(f"Retry to calculate for {gamemode}, total: {len(scores)}")
for best_score in prev.values(): for best_score in prev.values():
session.add(best_score) session.add(best_score)
async def _recalculate_best_score( async def _recalculate_best_score(
@@ -147,144 +161,145 @@ async def _recalculate_best_score(
gamemode: GameMode, gamemode: GameMode,
session: AsyncSession, session: AsyncSession,
): ):
beatmap_best_score: dict[int, list[BestScore]] = {} async with SEMAPHORE:
scores = ( beatmap_best_score: dict[int, list[BestScore]] = {}
await session.exec( scores = (
select(Score).where( await session.exec(
Score.gamemode == gamemode, select(Score).where(
col(Score.passed).is_(True), Score.gamemode == gamemode,
Score.user_id == user_id, col(Score.passed).is_(True),
Score.user_id == user_id,
)
) )
) ).all()
).all() for score in scores:
for score in scores: if not (
if not ( (await score.awaitable_attrs.beatmap).beatmap_status.has_leaderboard()
(await score.awaitable_attrs.beatmap).beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
| settings.enable_all_beatmap_leaderboard ):
): continue
continue mod_for_save = mod_to_save(score.mods)
mod_for_save = mod_to_save(score.mods) bs = BestScore(
bs = BestScore( user_id=score.user_id,
user_id=score.user_id, score_id=score.id,
score_id=score.id, beatmap_id=score.beatmap_id,
beatmap_id=score.beatmap_id, gamemode=score.gamemode,
gamemode=score.gamemode, total_score=score.total_score,
total_score=score.total_score, mods=mod_for_save,
mods=mod_for_save, rank=score.rank,
rank=score.rank,
)
if score.beatmap_id not in beatmap_best_score:
beatmap_best_score[score.beatmap_id] = [bs]
else:
b = next(
(
s
for s in beatmap_best_score[score.beatmap_id]
if s.mods == mod_for_save and s.beatmap_id == score.beatmap_id
),
None,
) )
if b is None: if score.beatmap_id not in beatmap_best_score:
beatmap_best_score[score.beatmap_id].append(bs) beatmap_best_score[score.beatmap_id] = [bs]
elif score.total_score > b.total_score: else:
beatmap_best_score[score.beatmap_id].remove(b) b = next(
beatmap_best_score[score.beatmap_id].append(bs) (
s
for s in beatmap_best_score[score.beatmap_id]
if s.mods == mod_for_save and s.beatmap_id == score.beatmap_id
),
None,
)
if b is None:
beatmap_best_score[score.beatmap_id].append(bs)
elif score.total_score > b.total_score:
beatmap_best_score[score.beatmap_id].remove(b)
beatmap_best_score[score.beatmap_id].append(bs)
for best_score_in_beatmap in beatmap_best_score.values(): for best_score_in_beatmap in beatmap_best_score.values():
for score in best_score_in_beatmap: for score in best_score_in_beatmap:
session.add(score) session.add(score)
async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession): async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession):
await session.refresh(statistics) async with SEMAPHORE:
pp_sum = 0 pp_sum = 0
acc_sum = 0 acc_sum = 0
bps = await get_user_best_pp(session, statistics.user_id, statistics.mode) bps = await get_user_best_pp(session, statistics.user_id, statistics.mode)
for i, s in enumerate(bps): for i, s in enumerate(bps):
pp_sum += calculate_weighted_pp(s.pp, i) pp_sum += calculate_weighted_pp(s.pp, i)
acc_sum += calculate_weighted_acc(s.acc, i) acc_sum += calculate_weighted_acc(s.acc, i)
if len(bps): if len(bps):
# https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45 # https://github.com/ppy/osu-queue-score-statistics/blob/c538ae/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/UserTotalPerformanceAggregateHelper.cs#L41-L45
acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(bps)))) acc_sum *= 100 / (20 * (1 - math.pow(0.95, len(bps))))
acc_sum = clamp(acc_sum, 0.0, 100.0) acc_sum = clamp(acc_sum, 0.0, 100.0)
statistics.pp = pp_sum statistics.pp = pp_sum
statistics.hit_accuracy = acc_sum statistics.hit_accuracy = acc_sum
statistics.play_count = 0 statistics.play_count = 0
statistics.total_score = 0 statistics.total_score = 0
statistics.maximum_combo = 0 statistics.maximum_combo = 0
statistics.play_time = 0 statistics.play_time = 0
statistics.total_hits = 0 statistics.total_hits = 0
statistics.count_100 = 0 statistics.count_100 = 0
statistics.count_300 = 0 statistics.count_300 = 0
statistics.count_50 = 0 statistics.count_50 = 0
statistics.count_miss = 0 statistics.count_miss = 0
statistics.ranked_score = 0 statistics.ranked_score = 0
statistics.grade_ss = 0 statistics.grade_ss = 0
statistics.grade_ssh = 0 statistics.grade_ssh = 0
statistics.grade_s = 0 statistics.grade_s = 0
statistics.grade_sh = 0 statistics.grade_sh = 0
statistics.grade_a = 0 statistics.grade_a = 0
scores = ( scores = (
await session.exec( await session.exec(
select(Score) select(Score)
.where( .where(
Score.user_id == statistics.user_id, Score.user_id == statistics.user_id,
Score.gamemode == statistics.mode, Score.gamemode == statistics.mode,
)
.options(joinedload(Score.beatmap))
) )
.options(joinedload(Score.beatmap)) ).all()
)
).all()
cached_beatmap_best: dict[int, Score] = {} cached_beatmap_best: dict[int, Score] = {}
for score in scores: for score in scores:
beatmap: Beatmap = score.beatmap beatmap: Beatmap = score.beatmap
ranked = beatmap.beatmap_status.has_pp() | settings.enable_all_mods_pp ranked = beatmap.beatmap_status.has_pp() | settings.enable_all_mods_pp
statistics.play_count += 1 statistics.play_count += 1
statistics.total_score += score.total_score statistics.total_score += score.total_score
playtime, is_valid = calculate_playtime(score, beatmap.hit_length) playtime, is_valid = calculate_playtime(score, beatmap.hit_length)
if is_valid: if is_valid:
statistics.play_time += playtime statistics.play_time += playtime
statistics.count_300 += score.n300 + score.ngeki statistics.count_300 += score.n300 + score.ngeki
statistics.count_100 += score.n100 + score.nkatu statistics.count_100 += score.n100 + score.nkatu
statistics.count_50 += score.n50 statistics.count_50 += score.n50
statistics.count_miss += score.nmiss statistics.count_miss += score.nmiss
statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50 statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50
if ranked and score.passed: if ranked and score.passed:
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
previous = cached_beatmap_best.get(score.beatmap_id) previous = cached_beatmap_best.get(score.beatmap_id)
difference = score.total_score - (previous.total_score if previous else 0) difference = score.total_score - (previous.total_score if previous else 0)
if difference > 0: if difference > 0:
cached_beatmap_best[score.beatmap_id] = score cached_beatmap_best[score.beatmap_id] = score
statistics.ranked_score += difference statistics.ranked_score += difference
match score.rank: match score.rank:
case Rank.X:
statistics.grade_ss += 1
case Rank.XH:
statistics.grade_ssh += 1
case Rank.S:
statistics.grade_s += 1
case Rank.SH:
statistics.grade_sh += 1
case Rank.A:
statistics.grade_a += 1
if previous is not None:
match previous.rank:
case Rank.X: case Rank.X:
statistics.grade_ss -= 1 statistics.grade_ss += 1
case Rank.XH: case Rank.XH:
statistics.grade_ssh -= 1 statistics.grade_ssh += 1
case Rank.S: case Rank.S:
statistics.grade_s -= 1 statistics.grade_s += 1
case Rank.SH: case Rank.SH:
statistics.grade_sh -= 1 statistics.grade_sh += 1
case Rank.A: case Rank.A:
statistics.grade_a -= 1 statistics.grade_a += 1
statistics.level_current = calculate_score_to_level(statistics.total_score) if previous is not None:
match previous.rank:
case Rank.X:
statistics.grade_ss -= 1
case Rank.XH:
statistics.grade_ssh -= 1
case Rank.S:
statistics.grade_s -= 1
case Rank.SH:
statistics.grade_sh -= 1
case Rank.A:
statistics.grade_a -= 1
statistics.level_current = calculate_score_to_level(statistics.total_score)
if __name__ == "__main__": if __name__ == "__main__":