refactor(recalculate): optimize batch processing and use semaphore
This commit is contained in:
@@ -33,6 +33,13 @@ from sqlalchemy.orm import joinedload
|
|||||||
from sqlmodel import col, delete, select
|
from sqlmodel import col, delete, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
SEMAPHORE = asyncio.Semaphore(50)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_in_batches(coros, batch_size=200):
|
||||||
|
for i in range(0, len(coros), batch_size):
|
||||||
|
await asyncio.gather(*coros[i : i + batch_size])
|
||||||
|
|
||||||
|
|
||||||
async def recalculate():
|
async def recalculate():
|
||||||
async with AsyncSession(engine, autoflush=False) as session:
|
async with AsyncSession(engine, autoflush=False) as session:
|
||||||
@@ -51,20 +58,26 @@ async def recalculate():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
await asyncio.gather(
|
await run_in_batches(
|
||||||
*[
|
[
|
||||||
_recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis)
|
_recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis)
|
||||||
for statistics in statistics_list
|
for statistics in statistics_list
|
||||||
]
|
],
|
||||||
|
batch_size=200,
|
||||||
)
|
)
|
||||||
await asyncio.gather(
|
await run_in_batches(
|
||||||
*[
|
[
|
||||||
_recalculate_best_score(statistics.user_id, statistics.mode, session)
|
_recalculate_best_score(statistics.user_id, statistics.mode, session)
|
||||||
for statistics in statistics_list
|
for statistics in statistics_list
|
||||||
]
|
],
|
||||||
|
batch_size=200,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list])
|
for statistics in statistics_list:
|
||||||
|
await session.refresh(statistics)
|
||||||
|
await run_in_batches(
|
||||||
|
[_recalculate_statistics(statistics, session) for statistics in statistics_list], batch_size=200
|
||||||
|
)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}")
|
logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}")
|
||||||
@@ -78,6 +91,7 @@ async def _recalculate_pp(
|
|||||||
fetcher: Fetcher,
|
fetcher: Fetcher,
|
||||||
redis: Redis,
|
redis: Redis,
|
||||||
):
|
):
|
||||||
|
async with SEMAPHORE:
|
||||||
scores = (
|
scores = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Score).where(
|
select(Score).where(
|
||||||
@@ -147,6 +161,7 @@ async def _recalculate_best_score(
|
|||||||
gamemode: GameMode,
|
gamemode: GameMode,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
):
|
):
|
||||||
|
async with SEMAPHORE:
|
||||||
beatmap_best_score: dict[int, list[BestScore]] = {}
|
beatmap_best_score: dict[int, list[BestScore]] = {}
|
||||||
scores = (
|
scores = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
@@ -196,7 +211,7 @@ async def _recalculate_best_score(
|
|||||||
|
|
||||||
|
|
||||||
async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession):
|
async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSession):
|
||||||
await session.refresh(statistics)
|
async with SEMAPHORE:
|
||||||
pp_sum = 0
|
pp_sum = 0
|
||||||
acc_sum = 0
|
acc_sum = 0
|
||||||
bps = await get_user_best_pp(session, statistics.user_id, statistics.mode)
|
bps = await get_user_best_pp(session, statistics.user_id, statistics.mode)
|
||||||
|
|||||||
Reference in New Issue
Block a user