refactor(recalculate): optimize batch processing and use semaphore

This commit is contained in:
MingxuanGame
2025-08-25 14:03:10 +00:00
parent 900fa9b121
commit 46b60e555f

View File

@@ -33,6 +33,13 @@ from sqlalchemy.orm import joinedload
from sqlmodel import col, delete, select from sqlmodel import col, delete, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
# Shared limiter with 50 permits. The _recalculate_pp, _recalculate_best_score and
# _recalculate_statistics coroutines enter it with `async with SEMAPHORE`, so at most
# 50 of them execute their guarded section at the same time.
SEMAPHORE = asyncio.Semaphore(50)
async def run_in_batches(coros, batch_size=200):
    """Await ``coros`` in consecutive groups of at most ``batch_size``.

    Each group is handed to ``asyncio.gather`` and must finish before the
    next group is started, so no more than ``batch_size`` awaitables are in
    flight at once.

    Args:
        coros: Iterable of awaitables (typically coroutine objects). It is
            materialized into a list, so generators are accepted too.
        batch_size: Maximum number of awaitables gathered together. Must be
            a positive integer.

    Returns:
        A list with the result of every awaitable, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        Exception: The first exception raised inside a batch is propagated,
            exactly as ``asyncio.gather`` raises it.
    """
    pending = list(coros)
    results = []
    # Number of leading items of ``pending`` that were handed to gather.
    started = 0
    try:
        # Validated inside the ``try`` so that, on rejection, the coroutine
        # objects already created by the caller still get closed below.
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        for offset in range(0, len(pending), batch_size):
            batch = pending[offset : offset + batch_size]
            started = offset + len(batch)
            results.extend(await asyncio.gather(*batch))
    finally:
        # Coroutine objects that never reached gather would otherwise emit
        # "coroutine ... was never awaited" when garbage collected. Only
        # plain coroutines are closed; tasks/futures are left untouched.
        for leftover in pending[started:]:
            if asyncio.iscoroutine(leftover):
                leftover.close()
    return results
async def recalculate(): async def recalculate():
async with AsyncSession(engine, autoflush=False) as session: async with AsyncSession(engine, autoflush=False) as session:
@@ -51,20 +58,26 @@ async def recalculate():
) )
) )
).all() ).all()
await asyncio.gather( await run_in_batches(
*[ [
_recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis) _recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis)
for statistics in statistics_list for statistics in statistics_list
] ],
batch_size=200,
) )
await asyncio.gather( await run_in_batches(
*[ [
_recalculate_best_score(statistics.user_id, statistics.mode, session) _recalculate_best_score(statistics.user_id, statistics.mode, session)
for statistics in statistics_list for statistics in statistics_list
] ],
batch_size=200,
) )
await session.commit() await session.commit()
await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list]) for statistics in statistics_list:
await session.refresh(statistics)
await run_in_batches(
[_recalculate_statistics(statistics, session) for statistics in statistics_list], batch_size=200
)
await session.commit() await session.commit()
logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}") logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}")
@@ -78,6 +91,7 @@ async def _recalculate_pp(
fetcher: Fetcher, fetcher: Fetcher,
redis: Redis, redis: Redis,
): ):
async with SEMAPHORE:
scores = ( scores = (
await session.exec( await session.exec(
select(Score).where( select(Score).where(
@@ -147,6 +161,7 @@ async def _recalculate_best_score(
gamemode: GameMode, gamemode: GameMode,
session: AsyncSession, session: AsyncSession,
): ):
async with SEMAPHORE:
beatmap_best_score: dict[int, list[BestScore]] = {} beatmap_best_score: dict[int, list[BestScore]] = {}
scores = ( scores = (
await session.exec( await session.exec(
@@ -196,7 +211,7 @@ async def _recalculate_best_score(
async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession): async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession):
await session.refresh(statistics) async with SEMAPHORE:
pp_sum = 0 pp_sum = 0
acc_sum = 0 acc_sum = 0
bps = await get_user_best_pp(session, statistics.user_id, statistics.mode) bps = await get_user_best_pp(session, statistics.user_id, statistics.mode)